Skip to content

Commit ced6de0

Browse files
committed
add '_broadcast_other'; fix min and max
1 parent 49de038 commit ced6de0

1 file changed

Lines changed: 87 additions & 7 deletions

File tree

  • src/gradient_free_optimizers/_array_backend

src/gradient_free_optimizers/_array_backend/_pure.py

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -305,6 +305,9 @@ def _getitem_fancy(self, idx):
305305

306306
is_bool = isinstance(idx_list[0], bool)
307307

308+
if not is_bool and isinstance(idx_list[0], float):
309+
idx_list = [int(i) for i in idx_list]
310+
308311
if self._ndim == 1:
309312
if is_bool:
310313
result = [self._data[i] for i, b in enumerate(idx_list) if b]
@@ -407,17 +410,30 @@ def __str__(self):
407410
]
408411
return str(rows)
409412

413+
def _broadcast_other(self, other):
414+
"""Tile 1D other._data to match self's 2D flat layout."""
415+
if self._ndim == 2 and other._ndim == 1:
416+
ncols = self._shape[1]
417+
nrows = self._shape[0]
418+
if len(other._data) == ncols:
419+
od = other._data
420+
if isinstance(od, _array_mod.array):
421+
return od * nrows
422+
return od * nrows
423+
return other._data
424+
410425
def _binop(self, other, op):
411426
if isinstance(other, GFOArray):
427+
other_data = self._broadcast_other(other)
412428
if isinstance(self._data, _array_mod.array) and isinstance(
413-
other._data, _array_mod.array
429+
other_data, _array_mod.array
414430
):
415431
return GFOArray._from_raw(
416-
_array_mod.array(_DOUBLE, map(op, self._data, other._data)),
432+
_array_mod.array(_DOUBLE, map(op, self._data, other_data)),
417433
self._shape,
418434
)
419435
return GFOArray._from_raw(
420-
list(map(op, self._data, other._data)), self._shape
436+
list(map(op, self._data, other_data)), self._shape
421437
)
422438
if isinstance(other, int | float):
423439
if isinstance(self._data, _array_mod.array):
@@ -505,6 +521,20 @@ def __abs__(self):
505521
def __invert__(self):
506522
return GFOArray._from_raw([not x for x in self._data], self._shape)
507523

524+
def __and__(self, other):
525+
if isinstance(other, GFOArray):
526+
return GFOArray._from_raw(
527+
[a and b for a, b in zip(self._data, other._data)], self._shape
528+
)
529+
return GFOArray._from_raw([a and other for a in self._data], self._shape)
530+
531+
def __or__(self, other):
532+
if isinstance(other, GFOArray):
533+
return GFOArray._from_raw(
534+
[a or b for a, b in zip(self._data, other._data)], self._shape
535+
)
536+
return GFOArray._from_raw([a or other for a in self._data], self._shape)
537+
508538
def __matmul__(self, other):
509539
if not isinstance(other, GFOArray):
510540
other = GFOArray(other)
@@ -693,12 +723,32 @@ def var(self, axis=None, ddof=0):
693723
def min(self, axis=None):
694724
if axis is None:
695725
return _min(self._data)
696-
raise NotImplementedError("Axis-aware min not implemented")
726+
if self._ndim != 2:
727+
raise NotImplementedError("Axis-aware min requires 2D array")
728+
nrows, ncols = self._shape
729+
data = self._data
730+
if axis == 1:
731+
result = [_min(data[r * ncols : (r + 1) * ncols]) for r in range(nrows)]
732+
else:
733+
result = [
734+
_min(data[r * ncols + c] for r in range(nrows)) for c in range(ncols)
735+
]
736+
return GFOArray._from_raw(_array_mod.array(_DOUBLE, result), (len(result),))
697737

698738
def max(self, axis=None):
699739
if axis is None:
700740
return _max(self._data)
701-
raise NotImplementedError("Axis-aware max not implemented")
741+
if self._ndim != 2:
742+
raise NotImplementedError("Axis-aware max requires 2D array")
743+
nrows, ncols = self._shape
744+
data = self._data
745+
if axis == 1:
746+
result = [_max(data[r * ncols : (r + 1) * ncols]) for r in range(nrows)]
747+
else:
748+
result = [
749+
_max(data[r * ncols + c] for r in range(nrows)) for c in range(ncols)
750+
]
751+
return GFOArray._from_raw(_array_mod.array(_DOUBLE, result), (len(result),))
702752

703753
def argmax(self, axis=None):
704754
if axis is None:
@@ -724,6 +774,11 @@ def argmin(self, axis=None):
724774
return best_i
725775
raise NotImplementedError("Axis-aware argmin not implemented")
726776

777+
def argsort(self, axis=-1):
778+
data = self._data
779+
indices = sorted(range(len(data)), key=data.__getitem__)
780+
return GFOArray._from_raw(indices, (len(indices),))
781+
727782
def any(self):
728783
return _any(self._data)
729784

@@ -1896,13 +1951,26 @@ def integers(self, low, high=None, size=None, endpoint=False):
18961951
else:
18971952
high_val = high
18981953
rng = self._rng
1954+
low_arr = isinstance(low, GFOArray)
1955+
high_arr = isinstance(high_val, GFOArray)
1956+
if low_arr or high_arr:
1957+
n = len(low) if low_arr else len(high_val)
1958+
lows = [int(v) for v in low._data] if low_arr else [int(low)] * n
1959+
highs = (
1960+
[int(v) for v in high_val._data] if high_arr else [int(high_val)] * n
1961+
)
1962+
return GFOArray._from_raw(
1963+
[float(rng.randint(l, h)) for l, h in zip(lows, highs)], (n,)
1964+
)
18991965
if size is None:
1900-
return rng.randint(low, high_val)
1966+
return rng.randint(int(low), int(high_val))
19011967
n = size if isinstance(size, int) else 1
19021968
if not isinstance(size, int):
19031969
for s in size:
19041970
n *= s
1905-
return GFOArray._from_raw([rng.randint(low, high_val) for _ in range(n)], (n,))
1971+
return GFOArray._from_raw(
1972+
[rng.randint(int(low), int(high_val)) for _ in range(n)], (n,)
1973+
)
19061974

19071975
def choice(self, a, size=None, replace=True, p=None):
19081976
if isinstance(a, int):
@@ -1937,6 +2005,18 @@ def choice(self, a, size=None, replace=True, p=None):
19372005

19382006
def normal(self, loc=0.0, scale=1.0, size=None):
19392007
rng = self._rng
2008+
loc_arr = isinstance(loc, GFOArray)
2009+
scale_arr = isinstance(scale, GFOArray)
2010+
if loc_arr or scale_arr:
2011+
n = len(scale) if scale_arr else len(loc)
2012+
locs = list(loc._data) if loc_arr else [loc] * n
2013+
scales = list(scale._data) if scale_arr else [scale] * n
2014+
return GFOArray._from_raw(
2015+
_array_mod.array(
2016+
_DOUBLE, (rng.gauss(l, s) for l, s in zip(locs, scales))
2017+
),
2018+
(n,),
2019+
)
19402020
if size is None:
19412021
return rng.gauss(loc, scale)
19422022
n = size if isinstance(size, int) else 1

0 commit comments

Comments
 (0)