Skip to content

Commit 551d842

Browse files
committed
Revert "Add docstring and parameter formatting improvements"
This reverts commit d3fdde2.
1 parent d3fdde2 commit 551d842

127 files changed

Lines changed: 60317 additions & 60254 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

brainpy/_src/_delay.py

Lines changed: 302 additions & 302 deletions
Large diffs are not rendered by default.

brainpy/_src/analysis/highdim/slow_points.py

Lines changed: 853 additions & 853 deletions
Large diffs are not rendered by default.

brainpy/_src/analysis/lowdim/lowdim_analyzer.py

Lines changed: 1048 additions & 1047 deletions
Large diffs are not rendered by default.

brainpy/_src/analysis/lowdim/lowdim_bifurcation.py

Lines changed: 623 additions & 623 deletions
Large diffs are not rendered by default.

brainpy/_src/analysis/lowdim/lowdim_phase_plane.py

Lines changed: 524 additions & 524 deletions
Large diffs are not rendered by default.

brainpy/_src/analysis/stability.py

Lines changed: 216 additions & 215 deletions
Large diffs are not rendered by default.
Lines changed: 119 additions & 119 deletions
Original file line numberDiff line numberDiff line change
@@ -1,119 +1,119 @@
1-
# -*- coding: utf-8 -*-
2-
3-
from functools import partial
4-
from typing import Union
5-
6-
import jax
7-
import jax.numpy as jnp
8-
import numpy as np
9-
from jax.tree_util import tree_flatten
10-
11-
import brainpy._src.math as bm
12-
from brainpy.tools import numba_jit
13-
14-
__all__ = [
15-
'find_indexes_of_limit_cycle_max',
16-
'euclidean_distance',
17-
'euclidean_distance_jax',
18-
]
19-
20-
21-
@numba_jit
22-
def _f1(arr, grad, tol):
23-
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
24-
indexes = np.where(condition)[0]
25-
if len(indexes) >= 2:
26-
data = arr[indexes[-2]: indexes[-1]]
27-
length = np.max(data) - np.min(data)
28-
a = arr[indexes[-2]]
29-
b = arr[indexes[-1]]
30-
# TODO: how to choose length threshold, 1e-3?
31-
if length > 1e-3 and np.abs(a - b) <= tol * length:
32-
return indexes[-2:]
33-
return np.array([-1, -1])
34-
35-
36-
def find_indexes_of_limit_cycle_max(arr, tol=0.001):
37-
grad = np.gradient(arr)
38-
return _f1(arr, grad, tol)
39-
40-
41-
@numba_jit
42-
def euclidean_distance(points: np.ndarray, num_point=None):
43-
"""Get the distance matrix.
44-
45-
Equivalent to:
46-
47-
>>> from scipy.spatial.distance import squareform, pdist
48-
>>> f = lambda points: squareform(pdist(points, metric="euclidean"))
49-
50-
Parameters::
51-
52-
points: ArrayType
53-
The points.
54-
55-
Returns::
56-
57-
dist_matrix: jnp.ndarray
58-
The distance matrix.
59-
"""
60-
61-
if isinstance(points, dict):
62-
if num_point is None:
63-
raise ValueError('Please provide num_point')
64-
indices = np.triu_indices(num_point)
65-
dist_mat = np.zeros((num_point, num_point))
66-
for idx in range(len(indices[0])):
67-
i = indices[0][idx]
68-
j = indices[1][idx]
69-
dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()]))
70-
else:
71-
num_point = points.shape[0]
72-
indices = np.triu_indices(num_point)
73-
dist_mat = np.zeros((num_point, num_point))
74-
for idx in range(len(indices[0])):
75-
i = indices[0][idx]
76-
j = indices[1][idx]
77-
dist_mat[i, j] = np.linalg.norm(points[i] - points[j])
78-
dist_mat = np.maximum(dist_mat, dist_mat.T)
79-
return dist_mat
80-
81-
82-
@jax.jit
83-
@partial(jax.vmap, in_axes=[0, 0, None])
84-
def _ed(i, j, leaves):
85-
squares = jnp.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves])
86-
return jnp.sqrt(jnp.sum(squares))
87-
88-
89-
def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=None):
90-
"""Get the distance matrix.
91-
92-
Equivalent to:
93-
94-
>>> from scipy.spatial.distance import squareform, pdist
95-
>>> f = lambda points: squareform(pdist(points, metric="euclidean"))
96-
97-
Parameters::
98-
99-
points: ArrayType
100-
The points.
101-
num_point: int
102-
103-
Returns::
104-
105-
dist_matrix: ArrayType
106-
The distance matrix.
107-
"""
108-
if isinstance(points, dict):
109-
if num_point is None:
110-
raise ValueError('Please provide num_point')
111-
else:
112-
num_point = points.shape[0]
113-
indices = jnp.triu_indices(num_point)
114-
dist_mat = bm.zeros((num_point, num_point))
115-
leaves, _ = tree_flatten(points)
116-
dist_mat[indices] = _ed(*indices, leaves)
117-
dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T)
118-
return dist_mat
119-
1+
# -*- coding: utf-8 -*-
2+
3+
from functools import partial
4+
from typing import Union
5+
6+
import jax
7+
import jax.numpy as jnp
8+
import numpy as np
9+
from jax.tree_util import tree_flatten
10+
11+
import brainpy._src.math as bm
12+
from brainpy.tools import numba_jit
13+
14+
__all__ = [
15+
'find_indexes_of_limit_cycle_max',
16+
'euclidean_distance',
17+
'euclidean_distance_jax',
18+
]
19+
20+
21+
@numba_jit
22+
def _f1(arr, grad, tol):
23+
condition = np.logical_and(grad[:-1] * grad[1:] <= 0, grad[:-1] >= 0)
24+
indexes = np.where(condition)[0]
25+
if len(indexes) >= 2:
26+
data = arr[indexes[-2]: indexes[-1]]
27+
length = np.max(data) - np.min(data)
28+
a = arr[indexes[-2]]
29+
b = arr[indexes[-1]]
30+
# TODO: how to choose length threshold, 1e-3?
31+
if length > 1e-3 and np.abs(a - b) <= tol * length:
32+
return indexes[-2:]
33+
return np.array([-1, -1])
34+
35+
36+
def find_indexes_of_limit_cycle_max(arr, tol=0.001):
37+
grad = np.gradient(arr)
38+
return _f1(arr, grad, tol)
39+
40+
41+
@numba_jit
42+
def euclidean_distance(points: np.ndarray, num_point=None):
43+
"""Get the distance matrix.
44+
45+
Equivalent to:
46+
47+
>>> from scipy.spatial.distance import squareform, pdist
48+
>>> f = lambda points: squareform(pdist(points, metric="euclidean"))
49+
50+
Parameters
51+
----------
52+
points: ArrayType
53+
The points.
54+
55+
Returns
56+
-------
57+
dist_matrix: jnp.ndarray
58+
The distance matrix.
59+
"""
60+
61+
if isinstance(points, dict):
62+
if num_point is None:
63+
raise ValueError('Please provide num_point')
64+
indices = np.triu_indices(num_point)
65+
dist_mat = np.zeros((num_point, num_point))
66+
for idx in range(len(indices[0])):
67+
i = indices[0][idx]
68+
j = indices[1][idx]
69+
dist_mat[i, j] = np.sqrt(np.sum([np.sum((value[i] - value[j]) ** 2) for value in points.values()]))
70+
else:
71+
num_point = points.shape[0]
72+
indices = np.triu_indices(num_point)
73+
dist_mat = np.zeros((num_point, num_point))
74+
for idx in range(len(indices[0])):
75+
i = indices[0][idx]
76+
j = indices[1][idx]
77+
dist_mat[i, j] = np.linalg.norm(points[i] - points[j])
78+
dist_mat = np.maximum(dist_mat, dist_mat.T)
79+
return dist_mat
80+
81+
82+
@jax.jit
83+
@partial(jax.vmap, in_axes=[0, 0, None])
84+
def _ed(i, j, leaves):
85+
squares = jnp.asarray([((leaf[i] - leaf[j]) ** 2).sum() for leaf in leaves])
86+
return jnp.sqrt(jnp.sum(squares))
87+
88+
89+
def euclidean_distance_jax(points: Union[jnp.ndarray, bm.ndarray], num_point=None):
90+
"""Get the distance matrix.
91+
92+
Equivalent to:
93+
94+
>>> from scipy.spatial.distance import squareform, pdist
95+
>>> f = lambda points: squareform(pdist(points, metric="euclidean"))
96+
97+
Parameters
98+
----------
99+
points: ArrayType
100+
The points.
101+
num_point: int
102+
103+
Returns
104+
-------
105+
dist_matrix: ArrayType
106+
The distance matrix.
107+
"""
108+
if isinstance(points, dict):
109+
if num_point is None:
110+
raise ValueError('Please provide num_point')
111+
else:
112+
num_point = points.shape[0]
113+
indices = jnp.triu_indices(num_point)
114+
dist_mat = bm.zeros((num_point, num_point))
115+
leaves, _ = tree_flatten(points)
116+
dist_mat[indices] = _ed(*indices, leaves)
117+
dist_mat = jnp.maximum(dist_mat.value, dist_mat.value.T)
118+
return dist_mat
119+

0 commit comments

Comments
 (0)