diff --git a/aesara/ifelse.py b/aesara/ifelse.py
index 29e240674d..fc1a606ec5 100644
--- a/aesara/ifelse.py
+++ b/aesara/ifelse.py
@@ -24,7 +24,7 @@
 from aesara.graph.op import _NoPythonOp
 from aesara.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
 from aesara.graph.type import HasDataType, HasShape
-from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast
+from aesara.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast, specify_shape
 
 
 if TYPE_CHECKING:
@@ -254,21 +254,34 @@ def grad(self, ins, grads):
         # Since input true/false entries must have the same dtypes, we need to
         # cast the zeros to the corresponding `grads` dtypes and not the input
         # dtypes.
-        inputs_true_grad = (
-            [condition]
-            + grads
-            + [
-                at.basic.zeros_like(t, dtype=grads[i].dtype)
-                for i, t in enumerate(inputs_true_branch)
-            ]
+        # The `grads` can also have different shapes than the `inputs`, so we
+        # effectively assert that the shapes are preserved in each branch.
+        # TODO FIXME: This doesn't seem like a sufficient solution to the
+        # problem.
+        inputs_true_grads = if_true_op(
+            *(
+                [condition]
+                + [specify_shape(g, i.shape) for g, i in zip(grads, inputs_true_branch)]
+                + [
+                    at.basic.zeros_like(t, dtype=grads[i].dtype)
+                    for i, t in enumerate(inputs_true_branch)
+                ]
+            ),
+            return_list=True,
         )
-        inputs_false_grad = (
-            [condition]
-            + [
-                at.basic.zeros_like(f, dtype=grads[i].dtype)
-                for i, f in enumerate(inputs_false_branch)
-            ]
-            + grads
+        inputs_false_grads = if_false_op(
+            *(
+                [condition]
+                + [
+                    at.basic.zeros_like(f, dtype=grads[i].dtype)
+                    for i, f in enumerate(inputs_false_branch)
+                ]
+                + [
+                    specify_shape(g, i.shape)
+                    for g, i in zip(grads, inputs_false_branch)
+                ]
+            ),
+            return_list=True,
         )
 
         # `condition` does affect the elements of the output so it is connected.
@@ -276,11 +289,7 @@ def grad(self, ins, grads):
         # condition + epsilon always triggers the same branch as condition
         condition_grad = condition.zeros_like().astype(config.floatX)
 
-        return (
-            [condition_grad]
-            + if_true_op(*inputs_true_grad, return_list=True)
-            + if_false_op(*inputs_false_grad, return_list=True)
-        )
+        return [condition_grad] + inputs_true_grads + inputs_false_grads
 
     def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
         cond = node.inputs[0]
diff --git a/aesara/sparse/sandbox/sp.py b/aesara/sparse/sandbox/sp.py
index 1f95d01758..6015006848 100644
--- a/aesara/sparse/sandbox/sp.py
+++ b/aesara/sparse/sandbox/sp.py
@@ -181,7 +181,7 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
 
                         # taking into account multiple
                         # input features
-                        col = (
+                        col = int(
                             iy * inshp[2] + ix + fmapi * np.prod(inshp[1:])
                         )
 
@@ -196,13 +196,13 @@ def evaluate(inshp, kshp, strides=(1, 1), nkern=1, mode="valid", ws=True):
 
                         # convert to row index of sparse matrix
                         if ws:
-                            row = (
+                            row = int(
                                 (y * outshp[1] + x) * inshp[0] * ksize
                                 + l
                                 + fmapi * ksize
                             )
                         else:
-                            row = y * outshp[1] + x
+                            row = int(y * outshp[1] + x)
 
                         # Store something at that location
                         # in sparse matrix.
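Note: the `ifelse.py` change above makes `IfElse.grad` wrap each output gradient in `specify_shape` so that its shape is asserted to match the corresponding branch input. The following is a minimal standalone sketch of the behaviour this is meant to preserve; it is not part of the patch, the variable names are purely illustrative, and it assumes an Aesara build with these changes applied:

    import numpy as np
    import aesara
    import aesara.tensor as at
    from aesara.ifelse import ifelse

    c = at.scalar("c")
    x = at.matrix("x")
    y = at.matrix("y")

    # The gradient of an `ifelse` output w.r.t. a branch input should keep that
    # input's shape, even though the static shapes of `x` and `y` are unknown.
    cost = ifelse(c, x, y).sum()
    gx = aesara.grad(cost, x)

    f = aesara.function([c, x, y], gx)
    # When the true branch is taken, d(cost)/dx is a matrix of ones with `x`'s shape.
    assert np.array_equal(f(1.0, np.zeros((2, 3)), np.zeros((2, 3))), np.ones((2, 3)))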
diff --git a/aesara/tensor/elemwise.py b/aesara/tensor/elemwise.py
index 34f9ea5459..8b6aada45b 100644
--- a/aesara/tensor/elemwise.py
+++ b/aesara/tensor/elemwise.py
@@ -258,15 +258,16 @@ def grad(self, inp, grads):
         (x,) = inp
         (gz,) = grads
         gz = as_tensor_variable(gz)
-        grad_order = ["x"] * len(x.type.broadcastable)
+        grad_order = ["x"] * x.type.ndim
         for i, v in enumerate(self.new_order):
             if v != "x":
                 grad_order[v] = i
+
         # Do not make the DimShuffle inplace as an optimization at the
         # canonicalization optimization phase will remove the inplace.
         # The inplace will be reintroduced automatically later in the graph.
-        if inp[0].dtype in discrete_dtypes:
-            return [inp[0].zeros_like(dtype=config.floatX)]
+        if x.dtype in discrete_dtypes:
+            return [x.zeros_like(dtype=config.floatX)]
         else:
             return [
                 DimShuffle(gz.type.broadcastable, grad_order)(
@@ -542,7 +543,6 @@ def connection_pattern(self, node):
         return [[True for output in node.outputs] for ipt in node.inputs]
 
     def L_op(self, inputs, outs, ograds):
-        from aesara.tensor.math import sum as at_sum
 
         # Compute grad with respect to broadcasted input
         rval = self._bgrad(inputs, outs, ograds)
@@ -573,18 +573,9 @@ def L_op(self, inputs, outs, ograds):
             if isinstance(rval[i].type, (NullType, DisconnectedType)):
                 continue
 
-            # List of all the dimensions that are broadcastable for input[i] so
-            # we can sum over them
-            # TODO: only count dimensions that were effectively broadcasted
-            to_sum = [
-                j
-                for j, bcast in enumerate(ipt.type.broadcastable)
-                if bcast and not outs[0].broadcastable[j]
-            ]
-
-            if to_sum:
-                sr = at_sum(rval[i], axis=to_sum, keepdims=True)
-                rval[i] = sr
+            rval[i] = aesara.tensor.extra_ops.sum_broadcasted_dims(
+                rval[i], ipt, outs[0].type.shape
+            )
 
         return rval
 
diff --git a/aesara/tensor/extra_ops.py b/aesara/tensor/extra_ops.py
index ca3e720339..6d6796e2f0 100644
--- a/aesara/tensor/extra_ops.py
+++ b/aesara/tensor/extra_ops.py
@@ -1,6 +1,6 @@
 from collections.abc import Collection
 from functools import reduce
-from typing import Iterable, Set, Tuple, Union
+from typing import Iterable, Optional, Sequence, Set, Tuple, Union
 
 import numpy as np
 import numpy.core.numeric
@@ -1669,19 +1669,11 @@ def grad(self, inputs, outputs_gradients):
 
         d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims)
 
-        # Determine the dimensions that were broadcast
-        _, static_shape = at.infer_static_shape(shape)
-
-        # TODO: This needs to be performed at run-time when static shape
-        # information isn't available.
-        bcast_sums = [
-            i
-            for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :]))
-            if a_s == 1 and s_s != 1
-        ]
-
-        if bcast_sums:
-            d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True)
+        # Determine the dimensions that were broadcast and sum them
+        static_out_shape = tuple(
+            s.data if isinstance(s, Constant) else None for s in shape[-a.ndim :]
+        )
+        d_wrt_a = sum_broadcasted_dims(d_wrt_a, a, static_out_shape)
 
         return [d_wrt_a] + [
             grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1)
@@ -1808,6 +1800,46 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]:
     return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args)
 
 
+def sum_broadcasted_dims(
+    value: TensorVariable,
+    inp: TensorVariable,
+    out_shape: Sequence[Optional[int]],
+) -> TensorVariable:
+    """Sum dimensions in `value` that are broadcasted between `inp`'s shape and `out_shape`.
+
+    For ambiguous cases, this builds a graph that determines whether or not
+    dimensions are to be summed at run-time.
+
+    """
+    dims_to_sum = ()
+    ambiguous_dim_conds = ()
+
+    in_shape = inp.type.shape
+
+    for i, (s1, s2) in enumerate(zip(in_shape, out_shape)):
+        if s1 == 1 and s2 != 1:
+            dims_to_sum += (i,)
+        elif s1 is None and s2 != 1:
+            ambiguous_dim_conds += (
+                (i, aes.eq(at.scalar_from_tensor(inp.shape[i]), 1)),
+            )
+
+    if dims_to_sum:
+        value = at_sum(value, axis=dims_to_sum, keepdims=True)
+
+    if ambiguous_dim_conds:
+        from aesara.ifelse import ifelse
+
+        for i, cond in ambiguous_dim_conds:
+            value = ifelse(
+                cond,
+                at_sum(value, axis=i, keepdims=True),
+                value,
+            )
+
+    return value
+
+
 __all__ = [
     "searchsorted",
     "cumsum",
diff --git a/aesara/tensor/math.py b/aesara/tensor/math.py
index 2b6724aa4b..1db8a4d6a8 100644
--- a/aesara/tensor/math.py
+++ b/aesara/tensor/math.py
@@ -504,7 +504,7 @@ def makeKeepDims(x, y, axis):
         newaxis.append(a)
     i = 0
     new_dims = []
-    for j, _ in enumerate(x.type.broadcastable):
+    for j in range(x.type.ndim):
         if j in newaxis:
             new_dims.append("x")
         else:
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 6bd514f277..59f62fe5be 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -8,8 +8,9 @@
 
 import aesara
 import aesara.scalar as aes
+import aesara.tensor as at
 import tests.unittest_tools as utt
-from aesara.compile.mode import Mode
+from aesara.compile.mode import Mode, get_default_mode
 from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Variable
 from aesara.graph.fg import FunctionGraph
@@ -889,6 +890,39 @@ def test_invalid_static_shape(self):
         ):
             x + y
 
+    def test_grad_sum_bcast_input_dims(self):
+        """Make sure broadcasted dimensions in the gradients are summed when static shape information isn't available."""
+        Y = matrix("Y")
+        X = matrix("X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        mode = get_default_mode().including("fast_run")
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        # When the shapes are known at compile-time, the compiled graph should
+        # simplify
+        Y = tensor(np.float64, shape=(5, None), name="Y")
+        X = tensor(np.float64, shape=(1, 5), name="X")
+        X_grad = aesara.grad((X + Y).sum(), wrt=X)
+
+        X_grad_fn = aesara.function([X, Y], X_grad, mode=mode)
+        res = X_grad_fn(np.ones((1, 5)), np.ones((5, 5)))
+        assert np.array_equal(res, np.array([[5.0, 5.0, 5.0, 5.0, 5.0]]))
+
+        assert X_grad_fn.maker.fgraph.apply_nodes
+
+    def test_grad_of_grad(self):
+        """This tests a special case in which the static shapes of a `DimShuffle` and its gradient don't match."""
+        a = at.vector("a")
+
+        out = aesara.grad((a * a).sum(), a).sum()
+        out = aesara.grad(out, a)
+
+        assert out.type.shape == (None,)
+
 
 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
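For reference, here is a hedged sketch of how the new `sum_broadcasted_dims` helper is intended to be used; it mirrors the call sites in `Elemwise.L_op` and `BroadcastTo.grad` above. The concrete shapes are illustrative only, and the example assumes this branch is installed:

    import numpy as np
    import aesara
    import aesara.tensor as at
    from aesara.tensor.extra_ops import sum_broadcasted_dims

    x = at.matrix("x")   # static shape (None, None): broadcasting is ambiguous
    g = at.ones((3, 4))  # stands in for a gradient broadcast up to shape (3, 4)

    # Sum `g` over every dimension in which `x` was broadcast to reach (3, 4).
    # Because `x`'s static shape is unknown, the helper builds run-time
    # `ifelse` checks on `x.shape[i] == 1` instead of deciding statically.
    summed = sum_broadcasted_dims(g, x, (3, 4))

    f = aesara.function([x], summed)
    # With an input of shape (1, 4), the first dimension was broadcast, so it is
    # summed with keepdims=True, giving [[3., 3., 3., 3.]].
    assert np.array_equal(f(np.zeros((1, 4))), np.full((1, 4), 3.0))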
diff --git a/tests/tensor/test_extra_ops.py b/tests/tensor/test_extra_ops.py
index 9c88453420..4037276e00 100644
--- a/tests/tensor/test_extra_ops.py
+++ b/tests/tensor/test_extra_ops.py
@@ -1318,6 +1318,7 @@ def test_memory_leak(self):
         [
             [lambda x: broadcast_to(x, (1,)), (1,)],
             [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
+            [lambda x: broadcast_to(x, (6, 2, 5, 3)), (1,)],
             [lambda x: broadcast_to(x, (6, 2, 5, 3)), (5, 1)],
             [lambda x: broadcast_to(x, (6, 2, 1, 3)), (2, 1, 3)],
         ],
diff --git a/tests/test_ifelse.py b/tests/test_ifelse.py
index 5eda9f2fba..9c6e9fdb86 100644
--- a/tests/test_ifelse.py
+++ b/tests/test_ifelse.py
@@ -718,3 +718,56 @@ def test_nested():
     linker = aesara.link.vm.VMLinker(lazy=True)
     f = function([c1, c2, x1, x2], t4, mode=Mode(linker=linker, optimizer="fast_run"))
     assert f(1, 0, np.array(10, dtype=x1.dtype), 0) == 20.5
+
+
+def test_DimShuffle_drop():
+    c = scalar("c")
+    x = scalar("x")
+    y = vector("y")
+
+    cost = ifelse(c, x.dimshuffle("x"), y).sum()
+
+    # Sum{acc_dtype=float64} [id A]
+    # |if{} [id B]
+    # |c [id C]
+    # |InplaceDimShuffle{x} [id D]
+    # | |x [id E]
+    # |y [id F]
+
+    out = aesara.grad(cost, y)
+    assert out.type.shape == (None,)
+
+    out = aesara.grad(cost, x)
+
+    #
+    # `DimShuffle.L_op` `inputs`
+    #
+    # x [id A]
+
+    #
+    # `DimShuffle.L_op` `outputs`
+    #
+    # InplaceDimShuffle{x} [id B]
+    # |x [id A]
+
+    #
+    # `DimShuffle.L_op` `output_grads`
+    #
+    # if{} [id C]
+    # |c [id D]
+    # |Elemwise{second} [id E]
+    # | |if{} [id F]
+    # | | |c [id D]
+    # | | |InplaceDimShuffle{x} [id B]
+    # | | |y [id G]
+    # | |InplaceDimShuffle{x} [id H]
+    # | |Elemwise{second,no_inplace} [id I]
+    # | |Sum{acc_dtype=float64} [id J]
+    # | | |if{} [id F]
+    # | |TensorConstant{1.0} [id K]
+    # |Elemwise{second,no_inplace} [id L]
+    # |InplaceDimShuffle{x} [id B]
+    # |InplaceDimShuffle{x} [id M]
+    # |TensorConstant{0.0} [id N]
+
+    assert out.type.shape == ()
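As a quick sanity check on `test_DimShuffle_drop` above, the gradient it builds can also be evaluated numerically. This is a hedged sketch, not part of the test suite; the numeric values are what the graph should produce under lazy `ifelse` evaluation, assuming this branch is installed:

    import numpy as np
    import aesara
    from aesara.ifelse import ifelse
    from aesara.tensor import scalar, vector

    c = scalar("c")
    x = scalar("x")
    y = vector("y")

    cost = ifelse(c, x.dimshuffle("x"), y).sum()
    dx = aesara.grad(cost, x)
    f = aesara.function([c, x, y], dx)

    # When the true branch is taken the cost is just `x`, so d(cost)/dx is 1;
    # otherwise the cost does not depend on `x` and the gradient is 0.
    assert np.isclose(f(1.0, 2.0, np.zeros(3)), 1.0)
    assert np.isclose(f(0.0, 2.0, np.zeros(3)), 0.0)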