Skip to content

Commit 66fd9ed

Browse files
kaushikcfdinducer
authored andcommitted
homogenize dipatching into pytato routines by characterizing into unary vs multi-ary
1 parent b24260d commit 66fd9ed

1 file changed

Lines changed: 11 additions & 36 deletions

File tree

arraycontext/impl/pytato/fake_numpy.py

Lines changed: 11 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -50,20 +50,29 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
5050
:ref:`Pytato docs <pytato:memory-layout>` for more on this.
5151
"""
5252

53-
_pt_funcs = frozenset({
53+
_pt_unary_funcs = frozenset({
5454
"sin", "cos", "tan", "arcsin", "arccos", "arctan",
5555
"sinh", "cosh", "tanh", "exp", "log", "log10",
5656
"sqrt", "abs", "isnan", "real", "imag", "conj",
5757
})
5858

59+
_pt_multi_ary_funcs = frozenset({
60+
"arctan2", "equal", "greater", "greater_equal", "less", "less_equal",
61+
"not_equal", "minimum", "maximum", "where",
62+
})
63+
5964
def _get_fake_numpy_linalg_namespace(self):
6065
return PytatoFakeNumpyLinalgNamespace(self._array_context)
6166

6267
def __getattr__(self, name):
63-
if name in self._pt_funcs:
68+
if name in self._pt_unary_funcs:
6469
from functools import partial
6570
return partial(rec_map_array_container, getattr(pt, name))
6671

72+
if name in self._pt_multi_ary_funcs:
73+
from functools import partial
74+
return partial(rec_multimap_array_container, getattr(pt, name))
75+
6776
return super().__getattr__(name)
6877

6978
# NOTE: the order of these follows the order in numpy docs
@@ -175,31 +184,10 @@ def rec_equal(x, y):
175184

176185
return rec_equal(a, b)
177186

178-
def greater(self, x, y):
179-
return rec_multimap_array_container(pt.greater, x, y)
180-
181-
def greater_equal(self, x, y):
182-
return rec_multimap_array_container(pt.greater_equal, x, y)
183-
184-
def less(self, x, y):
185-
return rec_multimap_array_container(pt.less, x, y)
186-
187-
def less_equal(self, x, y):
188-
return rec_multimap_array_container(pt.less_equal, x, y)
189-
190-
def equal(self, x, y):
191-
return rec_multimap_array_container(pt.equal, x, y)
192-
193-
def not_equal(self, x, y):
194-
return rec_multimap_array_container(pt.not_equal, x, y)
195-
196187
# }}}
197188

198189
# {{{ mathematical functions
199190

200-
def arctan2(self, y, x):
201-
return rec_multimap_array_container(pt.arctan2, y, x)
202-
203191
def sum(self, a, axis=None, dtype=None):
204192
def _pt_sum(ary):
205193
if dtype not in [ary.dtype, None]:
@@ -209,18 +197,12 @@ def _pt_sum(ary):
209197

210198
return rec_map_reduce_array_container(sum, _pt_sum, a)
211199

212-
def maximum(self, x, y):
213-
return rec_multimap_array_container(pt.maximum, x, y)
214-
215200
def amax(self, a, axis=None):
216201
return rec_map_reduce_array_container(
217202
partial(reduce, pt.maximum), partial(pt.amax, axis=axis), a)
218203

219204
max = amax
220205

221-
def minimum(self, x, y):
222-
return rec_multimap_array_container(pt.minimum, x, y)
223-
224206
def amin(self, a, axis=None):
225207
return rec_map_reduce_array_container(
226208
partial(reduce, pt.minimum), partial(pt.amin, axis=axis), a)
@@ -231,10 +213,3 @@ def absolute(self, a):
231213
return self.abs(a)
232214

233215
# }}}
234-
235-
# {{{ sorting, searching, and counting
236-
237-
def where(self, criterion, then, else_):
238-
return rec_multimap_array_container(pt.where, criterion, then, else_)
239-
240-
# }}}

0 commit comments

Comments
 (0)