Skip to content

Commit 2096a1b

Browse files
committed
Refactor import statements for consistency and add braintools dependency
1 parent e492eb8 commit 2096a1b

21 files changed

Lines changed: 802 additions & 2017 deletions

brainpy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104

105105
# Part: Others #
106106
# ---------------- #
107-
from brainpy._src.visualization import (visualize as visualize)
107+
import brainpy._src.visualization as visualize
108108

109109
# Part: Deprecations #
110110
# -------------------- #
Lines changed: 6 additions & 297 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,7 @@
11
# -*- coding: utf-8 -*-
22

33

4-
import numpy as onp
5-
from jax import vmap, lax, numpy as jnp
6-
7-
from brainpy._src import math as bm
8-
from brainpy.errors import UnsupportedError
4+
import braintools
95

106
__all__ = [
117
'cross_correlation',
@@ -16,295 +12,8 @@
1612
# 'functional_connectivity_dynamics',
1713
]
1814

19-
20-
def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'):
21-
r"""Calculate cross correlation index between neurons.
22-
23-
The coherence [1]_ between two neurons i and j is measured by their
24-
cross-correlation of spike trains at zero time lag within a time bin
25-
of :math:`\Delta t = \tau`. More specifically, suppose that a long
26-
time interval T is divided into small bins of :math:`\Delta t` and
27-
that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0
28-
or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence
29-
measure for the pair as:
30-
31-
.. math::
32-
33-
\kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)}
34-
{\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}}
35-
36-
The population coherence measure :math:`\kappa(\tau)` is defined by the
37-
average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the
38-
network.
39-
40-
.. note::
41-
To JIT compile this function, users should make ``bin``, ``dt``, ``numpy`` static.
42-
For example, ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``.
43-
44-
Parameters::
45-
46-
spikes : ndarray
47-
The history of spike states of the neuron group.
48-
bin : float, int
49-
The time bin to normalize spike states.
50-
dt : float, optional
51-
The time precision.
52-
numpy: bool
53-
Whether we use numpy array as the functional output.
54-
If ``False``, this function can be JIT compiled.
55-
method: str
56-
The method to calculate all pairs of cross correlation.
57-
Supports two kinds of methods: `loop` and `vmap`.
58-
`vmap` method needs much more memory.
59-
60-
.. versionadded:: 2.2.3.4
61-
62-
Returns::
63-
64-
cc_index : float
65-
The cross correlation value which represents the synchronization index.
66-
67-
References::
68-
69-
.. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic
70-
inhibition in a hippocampal interneuronal network model." Journal of
71-
neuroscience 16.20 (1996): 6402-6413.
72-
"""
73-
spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes)
74-
np = onp if numpy else jnp
75-
dt = bm.get_dt() if dt is None else dt
76-
bin_size = int(bin / dt)
77-
num_hist, num_neu = spikes.shape
78-
num_bin = int(onp.ceil(num_hist / bin_size))
79-
if num_bin * bin_size != num_hist:
80-
spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0)
81-
states = spikes.T.reshape((num_neu, num_bin, bin_size))
82-
states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_)
83-
indices = jnp.tril_indices(num_neu, k=-1)
84-
85-
if method == 'loop':
86-
def _f(i, j):
87-
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
88-
return lax.cond(sqrt_ij == 0.,
89-
lambda _: 0.,
90-
lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
91-
None)
92-
93-
res = bm.for_loop(_f, operands=indices)
94-
95-
elif method == 'vmap':
96-
@vmap
97-
def _cc(i, j):
98-
sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j]))
99-
return lax.cond(sqrt_ij == 0.,
100-
lambda _: 0.,
101-
lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij,
102-
None)
103-
104-
res = _cc(*indices)
105-
else:
106-
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')
107-
108-
return np.mean(np.asarray(res))
109-
110-
111-
def _f_signal(signal):
112-
return jnp.mean(signal * signal) - jnp.mean(signal) ** 2
113-
114-
115-
def voltage_fluctuation(potentials, numpy=True, method='loop'):
116-
r"""Calculate neuronal synchronization via voltage variance.
117-
118-
The method comes from [1]_ [2]_ [3]_.
119-
120-
First, average over the membrane potential :math:`V`
121-
122-
.. math::
123-
124-
V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t)
125-
126-
The variance of the time fluctuations of :math:`V(t)` is
127-
128-
.. math::
129-
130-
\sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t -
131-
\left[ \left\langle V(t) \right\rangle_t \right]^2
132-
133-
where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots`
134-
denotes time-averaging over a large time, :math:`\tau_m`. After normalization
135-
of :math:`\sigma_V` to the average over the population of the single cell
136-
membrane potentials
137-
138-
.. math::
139-
140-
\sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t -
141-
\left[ \left\langle V_i(t) \right\rangle_t \right]^2
142-
143-
one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system
144-
of :math:`N` neurons by:
145-
146-
.. math::
147-
148-
\chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N
149-
\sigma_{V_i}^2}
150-
151-
.. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled
152-
inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814.
153-
.. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled
154-
inhibitory neurons. Physica D 72:259-282.
155-
.. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347.
156-
157-
Args:
158-
potentials: The membrane potential matrix of the neuron group.
159-
numpy: Whether we use numpy array as the functional output. If ``False``, this function can be JIT compiled.
160-
method: The method to calculate all pairs of cross correlation.
161-
Supports two kinds of methods: `loop` and `vmap`.
162-
`vmap` method will consume much more memory.
163-
164-
.. versionadded:: 2.2.3.4
165-
166-
Returns:
167-
sync_index: The synchronization index.
168-
"""
169-
170-
potentials = bm.as_jax(potentials)
171-
avg = jnp.mean(potentials, axis=1)
172-
avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2
173-
174-
if method == 'loop':
175-
_var = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1))
176-
177-
elif method == 'vmap':
178-
_var = vmap(_f_signal, in_axes=1)(potentials)
179-
else:
180-
raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".')
181-
182-
var_mean = jnp.mean(_var)
183-
r = jnp.where(var_mean == 0., 1., avg_var / var_mean)
184-
return bm.as_numpy(r) if numpy else r
185-
186-
187-
def matrix_correlation(x, y, numpy=True):
188-
"""Pearson correlation of the lower triagonal of two matrices.
189-
190-
The triangular matrix is offset by k = 1 in order to ignore the diagonal line
191-
192-
Parameters::
193-
194-
x: ndarray
195-
First matrix.
196-
y: ndarray
197-
Second matrix
198-
numpy: bool
199-
Whether we use numpy array as the functional output.
200-
If ``False``, this function can be JIT compiled.
201-
202-
Returns::
203-
204-
coef: ndarray
205-
Correlation coefficient
206-
"""
207-
208-
x = bm.as_numpy(x) if numpy else bm.as_device_array(x)
209-
y = bm.as_numpy(y) if numpy else bm.as_device_array(y)
210-
np = onp if numpy else jnp
211-
if x.ndim != 2:
212-
raise ValueError(f'Only support 2d array, but we got a array '
213-
f'with the shape of {x.shape}')
214-
if y.ndim != 2:
215-
raise ValueError(f'Only support 2d array, but we got a array '
216-
f'with the shape of {y.shape}')
217-
x = x[np.triu_indices_from(x, k=1)]
218-
y = y[np.triu_indices_from(y, k=1)]
219-
cc = np.corrcoef(x, y)[0, 1]
220-
return cc
221-
222-
223-
def functional_connectivity(activities, numpy=True):
224-
"""Functional connectivity matrix of timeseries activities.
225-
226-
Parameters::
227-
228-
activities: ndarray
229-
The multidimensional array with the shape of ``(num_time, num_sample)``.
230-
numpy: bool
231-
Whether we use numpy array as the functional output.
232-
If ``False``, this function can be JIT compiled.
233-
234-
Returns::
235-
236-
connectivity_matrix: ndarray
237-
``num_sample x num_sample`` functional connectivity matrix.
238-
"""
239-
activities = bm.as_numpy(activities) if numpy else bm.as_device_array(activities)
240-
np = onp if numpy else jnp
241-
if activities.ndim != 2:
242-
raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". '
243-
f'But we got a array with the shape of {activities.shape}')
244-
fc = np.corrcoef(activities.T)
245-
return np.nan_to_num(fc)
246-
247-
248-
def functional_connectivity_dynamics(activities, window_size=30, step_size=5):
249-
"""Computes functional connectivity dynamics (FCD) matrix.
250-
251-
Parameters::
252-
253-
activities: ndarray
254-
The time series with shape of ``(num_time, num_sample)``.
255-
window_size: int
256-
Size of each rolling window in time steps, defaults to 30.
257-
step_size: int
258-
Step size between each rolling window, defaults to 5.
259-
260-
Returns::
261-
262-
fcd_matrix: ndarray
263-
FCD matrix.
264-
"""
265-
pass
266-
267-
268-
def weighted_correlation(x, y, w, numpy=True):
269-
"""Weighted Pearson correlation of two data series.
270-
271-
Parameters::
272-
273-
x: ndarray
274-
The data series 1.
275-
y: ndarray
276-
The data series 2.
277-
w: ndarray
278-
Weight vector, must have same length as x and y.
279-
numpy: bool
280-
Whether we use numpy array as the functional output.
281-
If ``False``, this function can be JIT compiled.
282-
283-
Returns::
284-
285-
corr: ndarray
286-
Weighted correlation coefficient.
287-
"""
288-
x = bm.as_numpy(x) if numpy else bm.as_device_array(x)
289-
y = bm.as_numpy(y) if numpy else bm.as_device_array(y)
290-
w = bm.as_numpy(w) if numpy else bm.as_device_array(w)
291-
np = onp if numpy else jnp
292-
293-
def _weighted_mean(x, w):
294-
"""Weighted Mean"""
295-
return np.sum(x * w) / np.sum(w)
296-
297-
def _weighted_cov(x, y, w):
298-
"""Weighted Covariance"""
299-
return np.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / np.sum(w)
300-
301-
if x.ndim != 1:
302-
raise ValueError(f'Only support 1d array, but we got a array '
303-
f'with the shape of {x.shape}')
304-
if y.ndim != 1:
305-
raise ValueError(f'Only support 1d array, but we got a array '
306-
f'with the shape of {y.shape}')
307-
if w.ndim != 1:
308-
raise ValueError(f'Only support 1d array, but we got a array '
309-
f'with the shape of {w.shape}')
310-
return _weighted_cov(x, y, w) / np.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w))
15+
cross_correlation = braintools.metric.cross_correlation
16+
voltage_fluctuation = braintools.metric.voltage_fluctuation
17+
matrix_correlation = braintools.metric.matrix_correlation
18+
functional_connectivity = braintools.metric.functional_connectivity
19+
weighted_correlation = braintools.metric.weighted_correlation

0 commit comments

Comments
 (0)