@@ -984,22 +984,34 @@ def aggregate(
984984 """
985985 xp = array_api_compat .array_namespace (data , owners )
986986
987+ def add_at (x , indices , values ):
988+ for idx , val in zip (indices , values ):
989+ x [idx ] = x [idx ] + val
990+ return x
991+
987992 def bincount (x , weights = None , minlength = 0 ):
988993 if weights is None :
989994 weights = xp .ones_like (x )
990995 result = xp .zeros ((max (minlength , int (x .max ()) + 1 ),), dtype = weights .dtype )
991- xp . add . at (result , x , weights )
996+ result = add_at (result , x , weights )
992997 return result
993998
994- bin_count = bincount (owners )
999+ if hasattr (xp , "bincount" ):
1000+ bin_count = xp .bincount (owners )
1001+ else :
1002+ # for array_api_strict
1003+ bin_count = bincount (owners )
9951004 bin_count = xp .where (bin_count == 0 , xp .ones_like (bin_count ), bin_count )
9961005
9971006 if num_owner is not None and bin_count .shape [0 ] != num_owner :
9981007 difference = num_owner - bin_count .shape [0 ]
9991008 bin_count = xp .concat ([bin_count , xp .ones (difference , dtype = bin_count .dtype )])
10001009
10011010 output = xp .zeros ((bin_count .shape [0 ], data .shape [1 ]), dtype = data .dtype )
1002- xp .add .at (output , owners , data )
1011+ if hasattr (xp , "add" ) and hasattr (xp .add , "at" ):
1012+ xp .add .at (output , owners , data )
1013+ else :
1014+ output = add_at (output , owners , data )
10031015
10041016 if average :
10051017 output = (output .T / bin_count ).T
0 commit comments