|
1 | 1 | # -*- coding: utf-8 -*- |
2 | 2 |
|
3 | 3 |
|
4 | | -import numpy as onp |
5 | | -from jax import vmap, lax, numpy as jnp |
6 | | - |
7 | | -from brainpy._src import math as bm |
8 | | -from brainpy.errors import UnsupportedError |
| 4 | +import braintools |
9 | 5 |
|
10 | 6 | __all__ = [ |
11 | 7 | 'cross_correlation', |
|
16 | 12 | # 'functional_connectivity_dynamics', |
17 | 13 | ] |
18 | 14 |
|
19 | | - |
20 | | -def cross_correlation(spikes, bin, dt=None, numpy=True, method='loop'): |
21 | | - r"""Calculate cross correlation index between neurons. |
22 | | -
|
23 | | - The coherence [1]_ between two neurons i and j is measured by their |
24 | | - cross-correlation of spike trains at zero time lag within a time bin |
25 | | - of :math:`\Delta t = \tau`. More specifically, suppose that a long |
26 | | - time interval T is divided into small bins of :math:`\Delta t` and |
27 | | - that two spike trains are given by :math:`X(l)=` 0 or 1, :math:`Y(l)=` 0 |
28 | | - or 1, :math:`l=1,2, \ldots, K(T / K=\tau)`. Thus, we define a coherence |
29 | | - measure for the pair as: |
30 | | -
|
31 | | - .. math:: |
32 | | -
|
33 | | - \kappa_{i j}(\tau)=\frac{\sum_{l=1}^{K} X(l) Y(l)} |
34 | | - {\sqrt{\sum_{l=1}^{K} X(l) \sum_{l=1}^{K} Y(l)}} |
35 | | -
|
36 | | - The population coherence measure :math:`\kappa(\tau)` is defined by the |
37 | | - average of :math:`\kappa_{i j}(\tau)` over many pairs of neurons in the |
38 | | - network. |
39 | | -
|
40 | | - .. note:: |
41 | | - To JIT compile this function, users should make ``bin``, ``dt``, ``numpy`` static. |
42 | | - For example, ``partial(brainpy.measure.cross_correlation, bin=10, numpy=False)``. |
43 | | -
|
44 | | - Parameters:: |
45 | | -
|
46 | | - spikes : ndarray |
47 | | - The history of spike states of the neuron group. |
48 | | - bin : float, int |
49 | | - The time bin to normalize spike states. |
50 | | - dt : float, optional |
51 | | - The time precision. |
52 | | - numpy: bool |
53 | | - Whether we use numpy array as the functional output. |
54 | | - If ``False``, this function can be JIT compiled. |
55 | | - method: str |
56 | | - The method to calculate all pairs of cross correlation. |
57 | | - Supports two kinds of methods: `loop` and `vmap`. |
58 | | - `vmap` method needs much more memory. |
59 | | -
|
60 | | - .. versionadded:: 2.2.3.4 |
61 | | -
|
62 | | - Returns:: |
63 | | -
|
64 | | - cc_index : float |
65 | | - The cross correlation value which represents the synchronization index. |
66 | | -
|
67 | | - References:: |
68 | | -
|
69 | | - .. [1] Wang, Xiao-Jing, and György Buzsáki. "Gamma oscillation by synaptic |
70 | | - inhibition in a hippocampal interneuronal network model." Journal of |
71 | | - neuroscience 16.20 (1996): 6402-6413. |
72 | | - """ |
73 | | - spikes = bm.as_numpy(spikes) if numpy else bm.as_jax(spikes) |
74 | | - np = onp if numpy else jnp |
75 | | - dt = bm.get_dt() if dt is None else dt |
76 | | - bin_size = int(bin / dt) |
77 | | - num_hist, num_neu = spikes.shape |
78 | | - num_bin = int(onp.ceil(num_hist / bin_size)) |
79 | | - if num_bin * bin_size != num_hist: |
80 | | - spikes = np.append(spikes, np.zeros((num_bin * bin_size - num_hist, num_neu)), axis=0) |
81 | | - states = spikes.T.reshape((num_neu, num_bin, bin_size)) |
82 | | - states = jnp.asarray(np.sum(states, axis=2) > 0., dtype=jnp.float_) |
83 | | - indices = jnp.tril_indices(num_neu, k=-1) |
84 | | - |
85 | | - if method == 'loop': |
86 | | - def _f(i, j): |
87 | | - sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) |
88 | | - return lax.cond(sqrt_ij == 0., |
89 | | - lambda _: 0., |
90 | | - lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, |
91 | | - None) |
92 | | - |
93 | | - res = bm.for_loop(_f, operands=indices) |
94 | | - |
95 | | - elif method == 'vmap': |
96 | | - @vmap |
97 | | - def _cc(i, j): |
98 | | - sqrt_ij = jnp.sqrt(jnp.sum(states[i]) * jnp.sum(states[j])) |
99 | | - return lax.cond(sqrt_ij == 0., |
100 | | - lambda _: 0., |
101 | | - lambda _: jnp.sum(states[i] * states[j]) / sqrt_ij, |
102 | | - None) |
103 | | - |
104 | | - res = _cc(*indices) |
105 | | - else: |
106 | | - raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".') |
107 | | - |
108 | | - return np.mean(np.asarray(res)) |
109 | | - |
110 | | - |
111 | | -def _f_signal(signal): |
112 | | - return jnp.mean(signal * signal) - jnp.mean(signal) ** 2 |
113 | | - |
114 | | - |
115 | | -def voltage_fluctuation(potentials, numpy=True, method='loop'): |
116 | | - r"""Calculate neuronal synchronization via voltage variance. |
117 | | -
|
118 | | - The method comes from [1]_ [2]_ [3]_. |
119 | | -
|
120 | | - First, average over the membrane potential :math:`V` |
121 | | -
|
122 | | - .. math:: |
123 | | -
|
124 | | - V(t) = \frac{1}{N} \sum_{i=1}^{N} V_i(t) |
125 | | -
|
126 | | - The variance of the time fluctuations of :math:`V(t)` is |
127 | | -
|
128 | | - .. math:: |
129 | | -
|
130 | | - \sigma_V^2 = \left\langle \left[ V(t) \right]^2 \right\rangle_t - |
131 | | - \left[ \left\langle V(t) \right\rangle_t \right]^2 |
132 | | -
|
133 | | - where :math:`\left\langle \ldots \right\rangle_t = (1 / T_m) \int_0^{T_m} dt \, \ldots` |
134 | | - denotes time-averaging over a large time, :math:`\tau_m`. After normalization |
135 | | - of :math:`\sigma_V` to the average over the population of the single cell |
136 | | - membrane potentials |
137 | | -
|
138 | | - .. math:: |
139 | | -
|
140 | | - \sigma_{V_i}^2 = \left\langle\left[ V_i(t) \right]^2 \right\rangle_t - |
141 | | - \left[ \left\langle V_i(t) \right\rangle_t \right]^2 |
142 | | -
|
143 | | - one defines a synchrony measure, :math:`\chi (N)`, for the activity of a system |
144 | | - of :math:`N` neurons by: |
145 | | -
|
146 | | - .. math:: |
147 | | -
|
148 | | - \chi^2 \left( N \right) = \frac{\sigma_V^2}{ \frac{1}{N} \sum_{i=1}^N |
149 | | - \sigma_{V_i}^2} |
150 | | -
|
151 | | - .. [1] Golomb, D. and Rinzel J. (1993) Dynamics of globally coupled |
152 | | - inhibitory neurons with heterogeneity. Phys. Rev. E 48:4810-4814. |
153 | | - .. [2] Golomb D. and Rinzel J. (1994) Clustering in globally coupled |
154 | | - inhibitory neurons. Physica D 72:259-282. |
155 | | - .. [3] David Golomb (2007) Neuronal synchrony measures. Scholarpedia, 2(1):1347. |
156 | | -
|
157 | | - Args: |
158 | | - potentials: The membrane potential matrix of the neuron group. |
159 | | - numpy: Whether we use numpy array as the functional output. If ``False``, this function can be JIT compiled. |
160 | | - method: The method to calculate all pairs of cross correlation. |
161 | | - Supports two kinds of methods: `loop` and `vmap`. |
162 | | - `vmap` method will consume much more memory. |
163 | | -
|
164 | | - .. versionadded:: 2.2.3.4 |
165 | | -
|
166 | | - Returns: |
167 | | - sync_index: The synchronization index. |
168 | | - """ |
169 | | - |
170 | | - potentials = bm.as_jax(potentials) |
171 | | - avg = jnp.mean(potentials, axis=1) |
172 | | - avg_var = jnp.mean(avg * avg) - jnp.mean(avg) ** 2 |
173 | | - |
174 | | - if method == 'loop': |
175 | | - _var = bm.for_loop(_f_signal, operands=jnp.moveaxis(potentials, 0, 1)) |
176 | | - |
177 | | - elif method == 'vmap': |
178 | | - _var = vmap(_f_signal, in_axes=1)(potentials) |
179 | | - else: |
180 | | - raise UnsupportedError(f'Do not support {method}. We only support "loop" or "vmap".') |
181 | | - |
182 | | - var_mean = jnp.mean(_var) |
183 | | - r = jnp.where(var_mean == 0., 1., avg_var / var_mean) |
184 | | - return bm.as_numpy(r) if numpy else r |
185 | | - |
186 | | - |
187 | | -def matrix_correlation(x, y, numpy=True): |
188 | | - """Pearson correlation of the lower triagonal of two matrices. |
189 | | -
|
190 | | - The triangular matrix is offset by k = 1 in order to ignore the diagonal line |
191 | | -
|
192 | | - Parameters:: |
193 | | -
|
194 | | - x: ndarray |
195 | | - First matrix. |
196 | | - y: ndarray |
197 | | - Second matrix |
198 | | - numpy: bool |
199 | | - Whether we use numpy array as the functional output. |
200 | | - If ``False``, this function can be JIT compiled. |
201 | | -
|
202 | | - Returns:: |
203 | | -
|
204 | | - coef: ndarray |
205 | | - Correlation coefficient |
206 | | - """ |
207 | | - |
208 | | - x = bm.as_numpy(x) if numpy else bm.as_device_array(x) |
209 | | - y = bm.as_numpy(y) if numpy else bm.as_device_array(y) |
210 | | - np = onp if numpy else jnp |
211 | | - if x.ndim != 2: |
212 | | - raise ValueError(f'Only support 2d array, but we got a array ' |
213 | | - f'with the shape of {x.shape}') |
214 | | - if y.ndim != 2: |
215 | | - raise ValueError(f'Only support 2d array, but we got a array ' |
216 | | - f'with the shape of {y.shape}') |
217 | | - x = x[np.triu_indices_from(x, k=1)] |
218 | | - y = y[np.triu_indices_from(y, k=1)] |
219 | | - cc = np.corrcoef(x, y)[0, 1] |
220 | | - return cc |
221 | | - |
222 | | - |
223 | | -def functional_connectivity(activities, numpy=True): |
224 | | - """Functional connectivity matrix of timeseries activities. |
225 | | -
|
226 | | - Parameters:: |
227 | | -
|
228 | | - activities: ndarray |
229 | | - The multidimensional array with the shape of ``(num_time, num_sample)``. |
230 | | - numpy: bool |
231 | | - Whether we use numpy array as the functional output. |
232 | | - If ``False``, this function can be JIT compiled. |
233 | | -
|
234 | | - Returns:: |
235 | | -
|
236 | | - connectivity_matrix: ndarray |
237 | | - ``num_sample x num_sample`` functional connectivity matrix. |
238 | | - """ |
239 | | - activities = bm.as_numpy(activities) if numpy else bm.as_device_array(activities) |
240 | | - np = onp if numpy else jnp |
241 | | - if activities.ndim != 2: |
242 | | - raise ValueError('Only support 2d array with shape of "(num_time, num_sample)". ' |
243 | | - f'But we got a array with the shape of {activities.shape}') |
244 | | - fc = np.corrcoef(activities.T) |
245 | | - return np.nan_to_num(fc) |
246 | | - |
247 | | - |
248 | | -def functional_connectivity_dynamics(activities, window_size=30, step_size=5): |
249 | | - """Computes functional connectivity dynamics (FCD) matrix. |
250 | | -
|
251 | | - Parameters:: |
252 | | -
|
253 | | - activities: ndarray |
254 | | - The time series with shape of ``(num_time, num_sample)``. |
255 | | - window_size: int |
256 | | - Size of each rolling window in time steps, defaults to 30. |
257 | | - step_size: int |
258 | | - Step size between each rolling window, defaults to 5. |
259 | | -
|
260 | | - Returns:: |
261 | | -
|
262 | | - fcd_matrix: ndarray |
263 | | - FCD matrix. |
264 | | - """ |
265 | | - pass |
266 | | - |
267 | | - |
268 | | -def weighted_correlation(x, y, w, numpy=True): |
269 | | - """Weighted Pearson correlation of two data series. |
270 | | -
|
271 | | - Parameters:: |
272 | | -
|
273 | | - x: ndarray |
274 | | - The data series 1. |
275 | | - y: ndarray |
276 | | - The data series 2. |
277 | | - w: ndarray |
278 | | - Weight vector, must have same length as x and y. |
279 | | - numpy: bool |
280 | | - Whether we use numpy array as the functional output. |
281 | | - If ``False``, this function can be JIT compiled. |
282 | | -
|
283 | | - Returns:: |
284 | | -
|
285 | | - corr: ndarray |
286 | | - Weighted correlation coefficient. |
287 | | - """ |
288 | | - x = bm.as_numpy(x) if numpy else bm.as_device_array(x) |
289 | | - y = bm.as_numpy(y) if numpy else bm.as_device_array(y) |
290 | | - w = bm.as_numpy(w) if numpy else bm.as_device_array(w) |
291 | | - np = onp if numpy else jnp |
292 | | - |
293 | | - def _weighted_mean(x, w): |
294 | | - """Weighted Mean""" |
295 | | - return np.sum(x * w) / np.sum(w) |
296 | | - |
297 | | - def _weighted_cov(x, y, w): |
298 | | - """Weighted Covariance""" |
299 | | - return np.sum(w * (x - _weighted_mean(x, w)) * (y - _weighted_mean(y, w))) / np.sum(w) |
300 | | - |
301 | | - if x.ndim != 1: |
302 | | - raise ValueError(f'Only support 1d array, but we got a array ' |
303 | | - f'with the shape of {x.shape}') |
304 | | - if y.ndim != 1: |
305 | | - raise ValueError(f'Only support 1d array, but we got a array ' |
306 | | - f'with the shape of {y.shape}') |
307 | | - if w.ndim != 1: |
308 | | - raise ValueError(f'Only support 1d array, but we got a array ' |
309 | | - f'with the shape of {w.shape}') |
310 | | - return _weighted_cov(x, y, w) / np.sqrt(_weighted_cov(x, x, w) * _weighted_cov(y, y, w)) |
| 15 | +cross_correlation = braintools.metric.cross_correlation |
| 16 | +voltage_fluctuation = braintools.metric.voltage_fluctuation |
| 17 | +matrix_correlation = braintools.metric.matrix_correlation |
| 18 | +functional_connectivity = braintools.metric.functional_connectivity |
| 19 | +weighted_correlation = braintools.metric.weighted_correlation |
0 commit comments