|
1 | 1 | from collections.abc import Collection |
2 | 2 | from functools import reduce |
3 | | -from typing import Iterable, Set, Tuple, Union |
| 3 | +from typing import Iterable, Sequence, Set, Tuple, Union |
4 | 4 |
|
5 | 5 | import numpy as np |
6 | 6 | import numpy.core.numeric |
@@ -1665,19 +1665,8 @@ def grad(self, inputs, outputs_gradients): |
1665 | 1665 |
|
1666 | 1666 | d_wrt_a = broadcast_to(dout, shape).sum(axis=new_dims) |
1667 | 1667 |
|
1668 | | - # Determine the dimensions that were broadcast |
1669 | | - _, static_shape = at.infer_static_shape(shape) |
1670 | | - |
1671 | | - # TODO: This needs to be performed at run-time when static shape |
1672 | | - # information isn't available. |
1673 | | - bcast_sums = [ |
1674 | | - i |
1675 | | - for i, (a_s, s_s) in enumerate(zip(a.type.shape, static_shape[-a.ndim :])) |
1676 | | - if a_s == 1 and s_s != 1 |
1677 | | - ] |
1678 | | - |
1679 | | - if bcast_sums: |
1680 | | - d_wrt_a = d_wrt_a.sum(axis=bcast_sums, keepdims=True) |
| 1668 | + # Determine the dimensions that were broadcast and sum them |
| 1669 | + d_wrt_a = sum_broadcastable_dims(d_wrt_a, a.shape, shape[-a.ndim :]) |
1681 | 1670 |
|
1682 | 1671 | return [d_wrt_a] + [ |
1683 | 1672 | grad_undefined(self, i, shp) for i, shp in enumerate(shape, 1) |
@@ -1804,6 +1793,33 @@ def broadcast_arrays(*args: TensorVariable) -> Tuple[TensorVariable, ...]: |
1804 | 1793 | return tuple(broadcast_to(a, broadcast_shape(*args)) for a in args) |
1805 | 1794 |
|
1806 | 1795 |
|
def sum_broadcastable_dims(
    value: TensorVariable,
    shape_1: Sequence[Variable],
    shape_2: Sequence[Variable],
) -> TensorVariable:
    """Sum dimensions in `value` that are broadcasted between `shape_1` and `shape_2`.

    For each aligned pair of shape entries ``(s1, s2)``, if at run-time
    ``s1 == 1`` while ``s2 != 1`` (i.e. that axis of the first shape was
    broadcast up to the second), the corresponding axis of `value` is summed
    with ``keepdims=True`` so `value`'s rank is preserved.  Unlike a static
    inspection of ``type.shape``, the decision is made symbolically via
    `ifelse`, so it works when static shape information is unavailable.

    NOTE(review): `shape_1` and `shape_2` are zipped, so they are assumed to
    have the same length (extra trailing entries in either are silently
    ignored) — callers are expected to pass pre-aligned shape sequences.
    """
    # Imported here (not at module level), presumably to avoid a circular
    # import between this module and `aesara.ifelse`.
    from aesara.ifelse import ifelse

    for i, (s1, s2) in enumerate(zip(shape_1, shape_2)):
        # Build a scalar-level condition ``(s1 == 1) and (s2 != 1)`` as a
        # single fused `Composite` Op over two dummy scalar inputs matching
        # the dtypes of the shape entries.
        dummy_s1 = aes.get_scalar_type(dtype=s1.type.dtype)()
        dummy_s2 = aes.get_scalar_type(dtype=s2.type.dtype)()
        cond_op = Composite(
            [dummy_s1, dummy_s2],
            [
                aesara.scalar.and_(
                    aesara.scalar.eq(dummy_s1, 1), aesara.scalar.neq(dummy_s2, 1)
                )
            ],
        )
        # Lazily choose between the summed and the unchanged tensor: only
        # the selected branch is evaluated at run-time.  ``keepdims=True``
        # keeps axis ``i`` as a size-1 dimension so later iterations' axis
        # indices remain valid.
        value = ifelse(
            cond_op(at.scalar_from_tensor(s1), at.scalar_from_tensor(s2)),
            at_sum(value, axis=i, keepdims=True),
            value,
        )
    return value
| 1822 | + |
1807 | 1823 | __all__ = [ |
1808 | 1824 | "searchsorted", |
1809 | 1825 | "cumsum", |
|
0 commit comments