@@ -50,20 +50,29 @@ class PytatoFakeNumpyNamespace(LoopyBasedFakeNumpyNamespace):
5050 :ref:`Pytato docs <pytato:memory-layout>` for more on this.
5151 """
5252
53- _pt_funcs = frozenset ({
53+ _pt_unary_funcs = frozenset ({
5454 "sin" , "cos" , "tan" , "arcsin" , "arccos" , "arctan" ,
5555 "sinh" , "cosh" , "tanh" , "exp" , "log" , "log10" ,
5656 "sqrt" , "abs" , "isnan" , "real" , "imag" , "conj" ,
5757 })
5858
59+ _pt_multi_ary_funcs = frozenset ({
60+ "arctan2" , "equal" , "greater" , "greater_equal" , "less" , "less_equal" ,
61+ "not_equal" , "minimum" , "maximum" , "where" ,
62+ })
63+
5964 def _get_fake_numpy_linalg_namespace (self ):
6065 return PytatoFakeNumpyLinalgNamespace (self ._array_context )
6166
6267 def __getattr__ (self , name ):
63- if name in self ._pt_funcs :
68+ if name in self ._pt_unary_funcs :
6469 from functools import partial
6570 return partial (rec_map_array_container , getattr (pt , name ))
6671
72+ if name in self ._pt_multi_ary_funcs :
73+ from functools import partial
74+ return partial (rec_multimap_array_container , getattr (pt , name ))
75+
6776 return super ().__getattr__ (name )
6877
6978 # NOTE: the order of these follows the order in numpy docs
@@ -175,31 +184,10 @@ def rec_equal(x, y):
175184
176185 return rec_equal (a , b )
177186
178- def greater (self , x , y ):
179- return rec_multimap_array_container (pt .greater , x , y )
180-
181- def greater_equal (self , x , y ):
182- return rec_multimap_array_container (pt .greater_equal , x , y )
183-
184- def less (self , x , y ):
185- return rec_multimap_array_container (pt .less , x , y )
186-
187- def less_equal (self , x , y ):
188- return rec_multimap_array_container (pt .less_equal , x , y )
189-
190- def equal (self , x , y ):
191- return rec_multimap_array_container (pt .equal , x , y )
192-
193- def not_equal (self , x , y ):
194- return rec_multimap_array_container (pt .not_equal , x , y )
195-
196187 # }}}
197188
198189 # {{{ mathematical functions
199190
200- def arctan2 (self , y , x ):
201- return rec_multimap_array_container (pt .arctan2 , y , x )
202-
203191 def sum (self , a , axis = None , dtype = None ):
204192 def _pt_sum (ary ):
205193 if dtype not in [ary .dtype , None ]:
@@ -209,18 +197,12 @@ def _pt_sum(ary):
209197
210198 return rec_map_reduce_array_container (sum , _pt_sum , a )
211199
212- def maximum (self , x , y ):
213- return rec_multimap_array_container (pt .maximum , x , y )
214-
215200 def amax (self , a , axis = None ):
216201 return rec_map_reduce_array_container (
217202 partial (reduce , pt .maximum ), partial (pt .amax , axis = axis ), a )
218203
219204 max = amax
220205
221- def minimum (self , x , y ):
222- return rec_multimap_array_container (pt .minimum , x , y )
223-
224206 def amin (self , a , axis = None ):
225207 return rec_map_reduce_array_container (
226208 partial (reduce , pt .minimum ), partial (pt .amin , axis = axis ), a )
@@ -231,10 +213,3 @@ def absolute(self, a):
231213 return self .abs (a )
232214
233215 # }}}
234-
235- # {{{ sorting, searching, and counting
236-
237- def where (self , criterion , then , else_ ):
238- return rec_multimap_array_container (pt .where , criterion , then , else_ )
239-
240- # }}}
0 commit comments