Skip to content

Commit 360d1c5

Browse files
committed
refactor: Replace Array with BaseArray in multiple files for consistency
1 parent 95a9560 commit 360d1c5

60 files changed

Lines changed: 587 additions & 566 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

brainpy/_src/_delay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ def __init__(
9292

9393
# delay data
9494
if before_t0 is not None:
95-
assert isinstance(before_t0, (int, float, bool, bm.Array, jax.Array, Callable))
95+
assert isinstance(before_t0, (int, float, bool, bm.BaseArray, jax.Array, Callable))
9696
self._before_t0 = before_t0
9797
if length > 0:
9898
self._init_data(length)
@@ -139,7 +139,7 @@ def register_entry(
139139
delay_type = 'none'
140140
elif isinstance(delay_step, int):
141141
delay_type = 'homo'
142-
elif isinstance(delay_step, (bm.Array, jax.Array, np.ndarray)):
142+
elif isinstance(delay_step, (bm.BaseArray, jax.Array, np.ndarray)):
143143
if delay_step.size == 1 and delay_step.ndim == 0:
144144
delay_type = 'homo'
145145
else:
@@ -296,7 +296,7 @@ def _init_data(self, length, batch_size: int = None):
296296
batch_axis=batch_axis)
297297
# update delay data
298298
self.data[0] = self.latest.value
299-
if isinstance(self._before_t0, (bm.Array, jax.Array, float, int, bool)):
299+
if isinstance(self._before_t0, (bm.BaseArray, jax.Array, float, int, bool)):
300300
self.data[1:] = self._before_t0
301301
elif callable(self._before_t0):
302302
self.data[1:] = self._before_t0((length,) + self.latest.shape, dtype=self.latest.dtype)

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -340,13 +340,13 @@ def find_fps_with_gd_method(
340340
num_candidate = self._check_candidates(candidates)
341341
if not (isinstance(candidates, (bm.ndarray, jnp.ndarray, np.ndarray)) or isinstance(candidates, dict)):
342342
raise ValueError('Candidates must be instance of ArrayType or dict of ArrayType.')
343-
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.Array))
343+
fixed_points = tree_map(lambda a: bm.TrainVar(a), candidates, is_leaf=lambda x: isinstance(x, bm.BaseArray))
344344
f_eval_loss = self._get_f_eval_loss()
345345

346346
def f_loss():
347347
return f_eval_loss(tree_map(lambda a: bm.as_jax(a),
348348
fixed_points,
349-
is_leaf=lambda x: isinstance(x, bm.Array))).mean()
349+
is_leaf=lambda x: isinstance(x, bm.BaseArray))).mean()
350350

351351
grad_f = bm.grad(f_loss, grad_vars=fixed_points, return_value=True)
352352
optimizer.register_train_vars(fixed_points if isinstance(fixed_points, dict) else {'a': fixed_points})
@@ -390,10 +390,10 @@ def batch_train(start_i, n_batch):
390390
self._opt_losses = jnp.concatenate(opt_losses)
391391
self._losses = f_eval_loss(tree_map(lambda a: bm.as_jax(a),
392392
fixed_points,
393-
is_leaf=lambda x: isinstance(x, bm.Array)))
393+
is_leaf=lambda x: isinstance(x, bm.BaseArray)))
394394
self._fixed_points = tree_map(lambda a: bm.as_jax(a),
395395
fixed_points,
396-
is_leaf=lambda x: isinstance(x, bm.Array))
396+
is_leaf=lambda x: isinstance(x, bm.BaseArray))
397397
self._selected_ids = jnp.arange(num_candidate)
398398

399399
if isinstance(self.target, DynamicalSystem):
@@ -429,7 +429,7 @@ def find_fps_with_opt_solver(
429429
print(f"Optimizing with {opt_solver} to find fixed points:")
430430

431431
# optimizing
432-
res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.Array)))
432+
res = f_opt(tree_map(lambda a: bm.as_jax(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray)))
433433

434434
# results
435435
valid_ids = jnp.where(res.success)[0]
@@ -562,7 +562,7 @@ def compute_jacobians(
562562
"""
563563
# check data
564564
info = np.asarray([(l.ndim, l.shape[0])
565-
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.Array))[0]])
565+
for l in tree_flatten(points, is_leaf=lambda a: isinstance(a, bm.BaseArray))[0]])
566566
ndim = np.unique(info[:, 0])
567567
if len(ndim) != 1: raise ValueError(f'Get multiple dimension of the evaluated points. {ndim}')
568568
if ndim[0] == 1:

brainpy/_src/analysis/lowdim/lowdim_analyzer.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -359,9 +359,9 @@ def _get_fixed_points(self, candidates, *args, num_seg=None, tol_aux=1e-7, loss_
359359
"""
360360
# candidates: xs, a vector with the length of self.resolutions[self.x_var]
361361
# args: parameters, a list/tuple of vectors
362-
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
362+
candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates
363363
selected_ids = np.arange(len(candidates))
364-
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
364+
args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args)
365365
for a in args: assert len(a) == len(candidates)
366366
if num_seg is None:
367367
num_seg = len(self.resolutions[self.x_var])
@@ -557,7 +557,7 @@ def f_jacobian(*var_and_pars):
557557
return self.F_fx(*var_and_pars), self.F_fy(*var_and_pars)
558558

559559
def call(*var_and_pars):
560-
var_and_pars = tuple((vp.value if isinstance(vp, bm.Array) else vp) for vp in var_and_pars)
560+
var_and_pars = tuple((vp.value if isinstance(vp, bm.BaseArray) else vp) for vp in var_and_pars)
561561
return jnp.array(jax.jit(f_jacobian, device=self.jit_device)(*var_and_pars))
562562

563563
self.analyzed_results[C.F_jacobian] = call
@@ -879,7 +879,7 @@ def _get_fp_candidates_by_aux_rank(self, num_segments=1, num_rank=100):
879879

880880
ps = tuple(p[ids[i]: ids[i] + arg_pre_len[i]] for i, p in enumerate(P))
881881
# change the position of meshgrid values
882-
vps = tuple((v.value if isinstance(v, bm.Array) else v) for v in ((xs, ys) + ps))
882+
vps = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in ((xs, ys) + ps))
883883
mesh_values = jnp.meshgrid(*vps)
884884
mesh_values = tuple(jnp.moveaxis(m, 0, 1) for m in mesh_values)
885885
mesh_values = tuple(m.flatten() for m in mesh_values)
@@ -934,9 +934,9 @@ def _get_fixed_points(self, candidates, *args, tol_aux=1e-7,
934934

935935
# candidates: xs, a vector with the length of self.resolutions[self.x_var]
936936
# args: parameters, a list/tuple of vectors
937-
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
937+
candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates
938938
selected_ids = np.arange(len(candidates))
939-
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
939+
args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args)
940940
for a in args: assert len(a) == len(candidates)
941941

942942
if self.convert_type() == C.x_by_y:

brainpy/_src/analysis/utils/function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
def f_without_jaxarray_return(f):
1818
def f2(*args, **kwargs):
1919
r = f(*args, **kwargs)
20-
return r.value if isinstance(r, bm.Array) else r
20+
return r.value if isinstance(r, bm.BaseArray) else r
2121

2222
return f2
2323

brainpy/_src/analysis/utils/optimization.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,21 +34,21 @@
3434

3535

3636
def _logical_or(a, b):
37-
a = a.value if isinstance(a, bm.Array) else a
38-
b = b.value if isinstance(b, bm.Array) else b
37+
a = a.value if isinstance(a, bm.BaseArray) else a
38+
b = b.value if isinstance(b, bm.BaseArray) else b
3939
return jnp.logical_or(a, b)
4040

4141

4242
def _logical_and(a, b):
43-
a = a.value if isinstance(a, bm.Array) else a
44-
b = b.value if isinstance(b, bm.Array) else b
43+
a = a.value if isinstance(a, bm.BaseArray) else a
44+
b = b.value if isinstance(b, bm.BaseArray) else b
4545
return jnp.logical_and(a, b)
4646

4747

4848
def _where(p, a, b):
49-
p = p.value if isinstance(p, bm.Array) else p
50-
a = a.value if isinstance(a, bm.Array) else a
51-
b = b.value if isinstance(b, bm.Array) else b
49+
p = p.value if isinstance(p, bm.BaseArray) else p
50+
a = a.value if isinstance(a, bm.BaseArray) else a
51+
b = b.value if isinstance(b, bm.BaseArray) else b
5252
return jnp.where(p, a, b)
5353

5454

@@ -175,7 +175,7 @@ def get_brentq_candidates(f, xs, ys):
175175

176176
def brentq_candidates(vmap_f, *values, args=()):
177177
# change the position of meshgrid values
178-
values = tuple((v.value if isinstance(v, bm.Array) else v) for v in values)
178+
values = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in values)
179179
xs = values[0]
180180
mesh_values = jnp.meshgrid(*values)
181181
if jnp.ndim(mesh_values[0]) > 1:
@@ -348,7 +348,7 @@ def scipy_minimize_with_jax(fun, x0,
348348
def fun_wrapper(x_flat, *args):
349349
x = unravel(x_flat)
350350
r = fun(x, *args)
351-
r = r.value if isinstance(r, bm.Array) else r
351+
r = r.value if isinstance(r, bm.BaseArray) else r
352352
return float(r)
353353

354354
# Wrap the gradient in a similar manner
@@ -386,8 +386,8 @@ def roots_of_1d_by_x(f, candidates, args=()):
386386
"""Find the roots of the given function by numerical methods.
387387
"""
388388
f = f_without_jaxarray_return(f)
389-
candidates = candidates.value if isinstance(candidates, bm.Array) else candidates
390-
args = tuple(a.value if isinstance(candidates, bm.Array) else a for a in args)
389+
candidates = candidates.value if isinstance(candidates, bm.BaseArray) else candidates
390+
args = tuple(a.value if isinstance(candidates, bm.BaseArray) else a for a in args)
391391
vals = f(candidates, *args)
392392
signs = jnp.sign(vals)
393393
zero_sign_idx = jnp.where(signs == 0)[0]

brainpy/_src/analysis/utils/others.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -72,16 +72,16 @@ def check_plot_durations(plot_durations, duration, initials):
7272

7373
def get_sign(f, xs, ys):
7474
f = f_without_jaxarray_return(f)
75-
xs = xs.value if isinstance(xs, bm.Array) else xs
76-
ys = ys.value if isinstance(ys, bm.Array) else ys
75+
xs = xs.value if isinstance(xs, bm.BaseArray) else xs
76+
ys = ys.value if isinstance(ys, bm.BaseArray) else ys
7777
Y, X = jnp.meshgrid(ys, xs)
7878
return jnp.sign(f(X, Y))
7979

8080

8181
def get_sign2(f, *xyz, args=()):
8282
in_axes = tuple(range(len(xyz))) + tuple([None] * len(args))
8383
f = jax.jit(jax.vmap(f_without_jaxarray_return(f), in_axes=in_axes))
84-
xyz = tuple((v.value if isinstance(v, bm.Array) else v) for v in xyz)
84+
xyz = tuple((v.value if isinstance(v, bm.BaseArray) else v) for v in xyz)
8585
XYZ = jnp.meshgrid(*xyz)
8686
XYZ = tuple(jnp.moveaxis(v, 1, 0).flatten() for v in XYZ)
8787
shape = (len(v) for v in xyz)
@@ -116,7 +116,7 @@ def keep_unique(candidates: Union[np.ndarray, Dict[str, np.ndarray]],
116116
return candidates, keep_ids
117117
if num_fps <= 1:
118118
return candidates, keep_ids
119-
candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.Array))
119+
candidates = tree_map(lambda a: np.asarray(a), candidates, is_leaf=lambda a: isinstance(a, bm.BaseArray))
120120

121121
# If point A and point B are within identical_tol of each other, and the
122122
# A is first in the list, we keep A.

brainpy/_src/checkpoints/serialization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
except ModuleNotFoundError:
3434
msgpack = None
3535

36-
from brainpy._src.math.ndarray import Array
36+
from brainpy._src.math.ndarray import BaseArray
3737
from brainpy.errors import (AlreadyExistsError,
3838
MPACheckpointingRequiredError,
3939
MPARestoreTargetRequiredError,
@@ -225,11 +225,11 @@ def _restore_list(xs, state_dict: Dict[str, Any]) -> List[Any]:
225225
return ys
226226

227227

228-
def _array_dict_state(x: Array) -> Dict[str, jax.Array]:
228+
def _array_dict_state(x: BaseArray) -> Dict[str, jax.Array]:
229229
return x.value
230230

231231

232-
def _restore_array(x, state_dict: jax.Array) -> Array:
232+
def _restore_array(x, state_dict: jax.Array) -> BaseArray:
233233
x.value = state_dict
234234
return x
235235

@@ -276,7 +276,7 @@ def _restore_namedtuple(xs, state_dict: Dict[str, Any]):
276276
return type(xs)(**fields)
277277

278278

279-
register_serialization_state(Array, _array_dict_state, _restore_array)
279+
register_serialization_state(BaseArray, _array_dict_state, _restore_array)
280280
register_serialization_state(dict, _dict_state_dict, _restore_dict)
281281
# register_serialization_state(DotDict, _dict_state_dict, _restore_dict)
282282
# register_serialization_state(Collector, _dict_state_dict, _restore_dict)

brainpy/_src/connect/custom_conn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ class MatConn(TwoEndConnector):
2222
def __init__(self, conn_mat, **kwargs):
2323
super(MatConn, self).__init__(**kwargs)
2424

25-
assert isinstance(conn_mat, (np.ndarray, bm.Array, jax.Array)) and conn_mat.ndim == 2
25+
assert isinstance(conn_mat, (np.ndarray, bm.BaseArray, jax.Array)) and conn_mat.ndim == 2
2626
self.pre_num, self.post_num = conn_mat.shape
2727
self.pre_size, self.post_size = (self.pre_num,), (self.post_num,)
2828

@@ -45,8 +45,8 @@ class IJConn(TwoEndConnector):
4545
def __init__(self, i, j, **kwargs):
4646
super(IJConn, self).__init__(**kwargs)
4747

48-
assert isinstance(i, (np.ndarray, bm.Array, jnp.ndarray)) and i.ndim == 1
49-
assert isinstance(j, (np.ndarray, bm.Array, jnp.ndarray)) and j.ndim == 1
48+
assert isinstance(i, (np.ndarray, bm.BaseArray, jnp.ndarray)) and i.ndim == 1
49+
assert isinstance(j, (np.ndarray, bm.BaseArray, jnp.ndarray)) and j.ndim == 1
5050
assert i.size == j.size
5151

5252
# initialize the class via "pre_ids" and "post_ids"

brainpy/_src/delay.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __init__(
102102

103103
# delay data
104104
if init is not None:
105-
assert isinstance(init, (numbers.Number, bm.Array, jax.Array, Callable))
105+
assert isinstance(init, (numbers.Number, bm.BaseArray, jax.Array, Callable))
106106
self._init = init
107107

108108
# other info
@@ -421,7 +421,7 @@ def _init_data(self, length: int, batch_size: int = None):
421421
else:
422422
self.data._value = data
423423
# update delay data
424-
if isinstance(self._init, (bm.Array, jax.Array, numbers.Number)):
424+
if isinstance(self._init, (bm.BaseArray, jax.Array, numbers.Number)):
425425
self.data[:] = self._init
426426
elif callable(self._init):
427427
self.data[:] = self._init((length,) + self.target.shape, dtype=self.target.dtype)
@@ -534,7 +534,7 @@ def init_delay_by_return(info: Union[bm.Variable, ReturnInfo], initial_delay_dat
534534
# init
535535
if isinstance(info.data, Callable):
536536
init = info.data(shape)
537-
elif isinstance(info.data, (bm.Array, jax.Array)):
537+
elif isinstance(info.data, (bm.BaseArray, jax.Array)):
538538
init = info.data
539539
else:
540540
raise TypeError

brainpy/_src/dnn/activations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def _inplace(inp, val, inplace):
1616
if inplace:
17-
assert isinstance(inp, bm.Array), 'input must be instance of brainpy.math.Array if inplace=True'
17+
assert isinstance(inp, bm.BaseArray), 'input must be instance of brainpy.math.Array if inplace=True'
1818
inp.value = val
1919
return inp
2020
else:

0 commit comments

Comments
 (0)