|
14 | 14 | from .compat_numpy import broadcast_to, expand_dims, concatenate |
15 | 15 | from .environment import get_dt, get_float |
16 | 16 | from .interoperability import as_jax |
17 | | -from .ndarray import ndarray, Array |
| 17 | +from .ndarray import ndarray, BaseArray |
18 | 18 | from .object_transform.base import BrainPyObject |
19 | 19 | from .object_transform.controls import cond |
20 | 20 | from .object_transform.variables import Variable |
|
29 | 29 |
|
30 | 30 |
|
31 | 31 | def _as_jax_array(arr): |
32 | | - return arr.value if isinstance(arr, Array) else arr |
| 32 | + return arr.value if isinstance(arr, BaseArray) else arr |
33 | 33 |
|
34 | 34 |
|
35 | 35 | class AbstractDelay(BrainPyObject): |
@@ -129,8 +129,8 @@ def __init__( |
129 | 129 | super(TimeDelay, self).__init__(name=name) |
130 | 130 |
|
131 | 131 | # shape |
132 | | - if not isinstance(delay_target, (jnp.ndarray, Array)): |
133 | | - raise ValueError(f'Must be an instance of Array or jax.numpy.ndarray. But we got {type(delay_target)}') |
| 132 | + if not isinstance(delay_target, (jnp.ndarray, BaseArray)): |
| 133 | + raise ValueError(f'Must be an instance of BaseArray or jax.numpy.ndarray. But we got {type(delay_target)}') |
134 | 134 |
|
135 | 135 | # delay_len |
136 | 136 | self.t0 = t0 |
@@ -453,7 +453,7 @@ def retrieve(self, delay_len, *indices): |
453 | 453 | # the delay data |
454 | 454 | return self.data[indices] |
455 | 455 |
|
456 | | - def update(self, value: Union[numbers.Number, Array, jax.Array] = None): |
| 456 | + def update(self, value: Union[numbers.Number, BaseArray, jax.Array] = None): |
457 | 457 | """Update delay variable with the new data. |
458 | 458 |
|
459 | 459 | Parameters |
|
0 commit comments