Skip to content

Commit 442b466

Browse files
committed
add array implementation with c extension
1 parent 5c65c44 commit 442b466

3 files changed

Lines changed: 619 additions & 9 deletions

File tree

src/gradient_free_optimizers/_array_backend/__init__.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,15 @@
22
Array backend abstraction for optional NumPy dependency.
33
44
This module provides a unified interface for array operations with automatic
5-
fallback to pure Python implementations when NumPy is not available.
5+
fallback: numpy (fastest) -> C extension (fast) -> pure Python (functional).
66
77
Usage:
88
from gradient_free_optimizers._array_backend import array, zeros, clip, rint
9-
10-
The backend automatically selects the fastest available implementation:
11-
- If NumPy is installed: uses NumPy (fast)
12-
- If not: uses pure Python GFOArray (slower but functional)
139
"""
1410

15-
# === Dependency Detection ===
16-
1711
try:
1812
import numpy
1913

20-
# Verify numpy is fully installed, not just a namespace stub
2114
_ = numpy.__version__
2215
from numpy import array as _test_array
2316

@@ -26,14 +19,24 @@
2619
except (ImportError, AttributeError):
2720
HAS_NUMPY = False
2821

22+
try:
23+
from . import _fast_ops # noqa: F401
2924

30-
# === Backend Selection ===
25+
HAS_C_EXTENSION = True
26+
except ImportError:
27+
HAS_C_EXTENSION = False
3128

3229
if HAS_NUMPY:
3330
from ._numpy import *
3431

3532
_backend_name = "numpy"
3633
ndarray = numpy.ndarray
34+
elif HAS_C_EXTENSION:
35+
from ._c_extension import *
36+
from ._pure import GFOArray
37+
38+
_backend_name = "c_extension"
39+
ndarray = GFOArray
3740
else:
3841
from ._pure import *
3942
from ._pure import GFOArray
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
"""C-accelerated array backend using the compiled _fast_ops module.
2+
3+
Wraps the pure Python GFOArray class but accelerates hot-path operations
4+
(arithmetic, math functions, reductions) through C loops. Falls back to
5+
pure Python for operations not covered by the C extension.
6+
7+
This module re-exports everything from _pure.py, then overrides the
8+
performance-critical functions with C-accelerated versions.
9+
"""
10+
11+
import array as _array_mod
12+
13+
from . import _fast_ops
14+
from ._pure import * # noqa: F401, F403
15+
from ._pure import _DOUBLE, GFOArray
16+
17+
_frombytes = _array_mod.array.frombytes
18+
19+
20+
def _c_result(raw_bytes, shape):
21+
data = _array_mod.array(_DOUBLE)
22+
_frombytes(data, raw_bytes)
23+
return GFOArray._from_raw(data, shape)
24+
25+
26+
class _CGFOArray(GFOArray):
27+
"""GFOArray subclass with C-accelerated arithmetic."""
28+
29+
def __add__(self, other):
30+
if (
31+
isinstance(other, GFOArray)
32+
and isinstance(self._data, _array_mod.array)
33+
and isinstance(other._data, _array_mod.array)
34+
):
35+
return _c_result(_fast_ops.vec_add(self._data, other._data), self._shape)
36+
if isinstance(other, int | float) and isinstance(self._data, _array_mod.array):
37+
return _c_result(
38+
_fast_ops.vec_add_scalar(self._data, float(other)),
39+
self._shape,
40+
)
41+
return super().__add__(other)
42+
43+
def __radd__(self, other):
44+
return self.__add__(other)
45+
46+
def __sub__(self, other):
47+
if (
48+
isinstance(other, GFOArray)
49+
and isinstance(self._data, _array_mod.array)
50+
and isinstance(other._data, _array_mod.array)
51+
):
52+
return _c_result(_fast_ops.vec_sub(self._data, other._data), self._shape)
53+
if isinstance(other, int | float) and isinstance(self._data, _array_mod.array):
54+
return _c_result(
55+
_fast_ops.vec_add_scalar(self._data, -float(other)),
56+
self._shape,
57+
)
58+
return super().__sub__(other)
59+
60+
def __mul__(self, other):
61+
if (
62+
isinstance(other, GFOArray)
63+
and isinstance(self._data, _array_mod.array)
64+
and isinstance(other._data, _array_mod.array)
65+
):
66+
return _c_result(_fast_ops.vec_mul(self._data, other._data), self._shape)
67+
if isinstance(other, int | float) and isinstance(self._data, _array_mod.array):
68+
return _c_result(
69+
_fast_ops.vec_mul_scalar(self._data, float(other)),
70+
self._shape,
71+
)
72+
return super().__mul__(other)
73+
74+
def __rmul__(self, other):
75+
return self.__mul__(other)
76+
77+
def __neg__(self):
78+
if isinstance(self._data, _array_mod.array):
79+
return _c_result(_fast_ops.vec_neg(self._data), self._shape)
80+
return super().__neg__()
81+
82+
def sum(self, axis=None):
83+
if axis is None and isinstance(self._data, _array_mod.array):
84+
return _fast_ops.vec_sum(self._data)
85+
return super().sum(axis=axis)
86+
87+
def argmax(self, axis=None):
88+
if axis is None and isinstance(self._data, _array_mod.array):
89+
return _fast_ops.vec_argmax(self._data)
90+
return super().argmax(axis=axis)
91+
92+
def __matmul__(self, other):
93+
if not isinstance(other, GFOArray):
94+
other = GFOArray(other)
95+
if (
96+
self._ndim == 1
97+
and other._ndim == 1
98+
and isinstance(self._data, _array_mod.array)
99+
and isinstance(other._data, _array_mod.array)
100+
):
101+
return _fast_ops.vec_dot(self._data, other._data)
102+
if (
103+
self._ndim == 2
104+
and other._ndim == 2
105+
and isinstance(self._data, _array_mod.array)
106+
and isinstance(other._data, _array_mod.array)
107+
):
108+
m, k = self._shape
109+
k2, n = other._shape
110+
if k == k2:
111+
return _c_result(
112+
_fast_ops.mat_mul(self._data, other._data, m, k, n),
113+
(m, n),
114+
)
115+
return super().__matmul__(other)
116+
117+
118+
def _c_buf(raw_bytes):
119+
data = _array_mod.array(_DOUBLE)
120+
_frombytes(data, raw_bytes)
121+
return data
122+
123+
124+
def array(data, dtype=None):
125+
"""Create a C-accelerated GFOArray."""
126+
base = GFOArray(data, dtype=dtype)
127+
return _CGFOArray._from_raw(base._data, base._shape)
128+
129+
130+
def zeros(shape, dtype=float):
131+
from ._pure import zeros as _pure_zeros
132+
133+
base = _pure_zeros(shape, dtype)
134+
if isinstance(base._data, _array_mod.array):
135+
return _CGFOArray._from_raw(base._data, base._shape)
136+
return base
137+
138+
139+
def ones(shape, dtype=float):
140+
from ._pure import ones as _pure_ones
141+
142+
base = _pure_ones(shape, dtype)
143+
if isinstance(base._data, _array_mod.array):
144+
return _CGFOArray._from_raw(base._data, base._shape)
145+
return base
146+
147+
148+
def empty(shape, dtype=float):
149+
from ._pure import empty as _pure_empty
150+
151+
base = _pure_empty(shape, dtype)
152+
if isinstance(base._data, _array_mod.array):
153+
return _CGFOArray._from_raw(base._data, base._shape)
154+
return base
155+
156+
157+
def full(shape, fill_value, dtype=None):
158+
from ._pure import full as _pure_full
159+
160+
base = _pure_full(shape, fill_value, dtype)
161+
if isinstance(base._data, _array_mod.array):
162+
return _CGFOArray._from_raw(base._data, base._shape)
163+
return base
164+
165+
166+
def exp(x):
167+
if isinstance(x, GFOArray) and isinstance(x._data, _array_mod.array):
168+
return _c_result(_fast_ops.vec_exp(x._data), x._shape)
169+
from ._pure import exp as _pure_exp
170+
171+
return _pure_exp(x)
172+
173+
174+
def log(x):
175+
if isinstance(x, GFOArray) and isinstance(x._data, _array_mod.array):
176+
return _c_result(_fast_ops.vec_log(x._data), x._shape)
177+
from ._pure import log as _pure_log
178+
179+
return _pure_log(x)
180+
181+
182+
def sqrt(x):
183+
if isinstance(x, GFOArray) and isinstance(x._data, _array_mod.array):
184+
return _c_result(_fast_ops.vec_sqrt(x._data), x._shape)
185+
from ._pure import sqrt as _pure_sqrt
186+
187+
return _pure_sqrt(x)
188+
189+
190+
def clip(x, a_min, a_max):
191+
if (
192+
isinstance(x, GFOArray)
193+
and isinstance(x._data, _array_mod.array)
194+
and isinstance(a_min, int | float)
195+
and isinstance(a_max, int | float)
196+
):
197+
return _c_result(
198+
_fast_ops.vec_clip(x._data, float(a_min), float(a_max)),
199+
x._shape,
200+
)
201+
from ._pure import clip as _pure_clip
202+
203+
return _pure_clip(x, a_min, a_max)
204+
205+
206+
def sum(x, axis=None):
207+
if (
208+
isinstance(x, GFOArray)
209+
and axis is None
210+
and isinstance(x._data, _array_mod.array)
211+
):
212+
return _fast_ops.vec_sum(x._data)
213+
from ._pure import sum as _pure_sum
214+
215+
return _pure_sum(x, axis)
216+
217+
218+
def argmax(x, axis=None):
219+
if (
220+
isinstance(x, GFOArray)
221+
and axis is None
222+
and isinstance(x._data, _array_mod.array)
223+
):
224+
return _fast_ops.vec_argmax(x._data)
225+
from ._pure import argmax as _pure_argmax
226+
227+
return _pure_argmax(x, axis)
228+
229+
230+
def dot(a, b):
231+
if (
232+
isinstance(a, GFOArray)
233+
and isinstance(b, GFOArray)
234+
and isinstance(a._data, _array_mod.array)
235+
and isinstance(b._data, _array_mod.array)
236+
and len(a._data) == len(b._data)
237+
):
238+
return _fast_ops.vec_dot(a._data, b._data)
239+
from ._pure import dot as _pure_dot
240+
241+
return _pure_dot(a, b)

0 commit comments

Comments
 (0)