Commit 3deb4d8

Ramdam17 and claude authored
docs(sync): enrich GPU metric and kernel docstrings (#278)
Restore the richer module docstrings (intent + literature references) for the 9 sync metrics, base.py, and the package __init__; expand the CUDA/Metal kernel module docstrings (kernel list, fp64-on-A100 / fp32-on-Metal rationale, why the pairwise dispatch avoids OOM); restore the _VRAM_THRESHOLD attribute docstring and enrich run_pairwise_kernel.

mkdocstrings cross-refs are flattened to plain text for tooltip legibility; no reference to non-existent docs files.

Docstrings only — no behavior change. sync parity tests pass (74 passed / 9 CuPy skipped on M4 Max).

Co-authored-by: Claude Fable 5 <noreply@anthropic.com>
1 parent f720414 commit 3deb4d8

16 files changed

Lines changed: 255 additions & 30 deletions

hypyp/sync/__init__.py

Lines changed: 35 additions & 2 deletions
@@ -4,8 +4,41 @@
 """
 Synchrony and connectivity metrics for hyperscanning analysis.
 
-This module provides a collection of connectivity metrics that can be used
-to measure neural synchronization between participants.
+Public API
+----------
+``BaseMetric``
+    Abstract base class. Concrete metrics inherit from it and implement
+    ``BaseMetric.compute``.
+
+Concrete metric classes (one per file):
+
+- ``PLV`` (``hypyp.sync.plv``) — Phase Locking Value.
+- ``CCorr`` (``hypyp.sync.ccorr``) — Circular Correlation.
+- ``ACCorr`` (``hypyp.sync.accorr``) — Adjusted Circular Correlation.
+- ``Coh`` (``hypyp.sync.coh``) — Coherence.
+- ``ImCoh`` (``hypyp.sync.imaginary_coh``) — Imaginary Coherence.
+- ``PLI`` (``hypyp.sync.pli``) — Phase Lag Index.
+- ``WPLI`` (``hypyp.sync.wpli``) — Weighted Phase Lag Index.
+- ``EnvCorr`` (``hypyp.sync.envelope_corr``) — Envelope Correlation.
+- ``PowCorr`` (``hypyp.sync.pow_corr``) — Power Correlation.
+
+Helpers
+-------
+``multiply_conjugate``, ``multiply_conjugate_time``,
+``multiply_product``, ``multiply_conjugate_torch``,
+``multiply_conjugate_time_torch``
+    Einsum building blocks shared across the metric implementations.
+
+Dispatcher
+----------
+``METRICS``
+    Dict mapping mode strings to metric classes.
+``get_metric``
+    Lookup helper used by ``hypyp.eeg.analyses.compute_sync`` and
+    ``hypyp.eeg.analyses.pair_connectivity``.
+
+Per-metric mathematical details and references live in each metric
+class's docstring.
 """
 
 from typing import Optional
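
The Dispatcher section of this docstring describes ``METRICS`` as a dict from mode strings to metric classes and ``get_metric`` as a lookup helper. A minimal sketch of that registry pattern follows. Every name ending in ``Sketch`` or ``_sketch`` is a hypothetical stand-in, the mode strings are assumed, and the real ``get_metric`` signature in ``hypyp.sync`` may differ.

```python
from abc import ABC, abstractmethod


class BaseMetricSketch(ABC):
    """Hypothetical stand-in for ``BaseMetric``; not the HyPyP class."""

    @abstractmethod
    def compute(self, complex_signal):
        """Return a connectivity estimate for ``complex_signal``."""


class PLVSketch(BaseMetricSketch):
    def compute(self, complex_signal):
        raise NotImplementedError("placeholder for illustration only")


class CohSketch(BaseMetricSketch):
    def compute(self, complex_signal):
        raise NotImplementedError("placeholder for illustration only")


# Assumed mode strings; the actual keys of METRICS are not shown in this diff.
METRICS_SKETCH = {"plv": PLVSketch, "coh": CohSketch}


def get_metric_sketch(mode: str) -> BaseMetricSketch:
    """Instantiate the metric class registered under ``mode``."""
    try:
        return METRICS_SKETCH[mode]()
    except KeyError:
        known = ", ".join(sorted(METRICS_SKETCH))
        raise ValueError(f"unknown mode {mode!r}; expected one of: {known}") from None
```

Keeping the registry as a plain dict means adding a metric is a one-line change, and an unknown mode fails early with the list of valid keys.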

hypyp/sync/accorr.py

Lines changed: 33 additions & 7 deletions
@@ -4,11 +4,22 @@
 """
 Adjusted Circular Correlation (ACCorr) connectivity metric.
 
-ACCorr computes the circular correlation between two phase time-series with
-per-pair phase centering, providing a more accurate inter-brain synchrony
-estimate than standard circular correlation (ccorr).
-
-Reference: Zimmermann et al. (2024). *Imaging Neuroscience*, 2.
+ACCorr computes the circular correlation between two phase time-series
+with **per-pair** phase centering, providing a more accurate inter-brain
+synchrony estimate than standard circular correlation (ccorr).
+
+See the ``ACCorr`` class for the public API; it supports a CPU
+``precompute`` strategy (vectorised numerator + loop denominator with
+pre-computed per-pair adjustments), a numba JIT backend, and PyTorch
+GPU/MPS backends. The torch implementation switches between a fully
+vectorised 5-D broadcast and a per-pair loop based on
+``ACCorr._VRAM_THRESHOLD`` — see that attribute's docstring.
+
+References
+----------
+Zimmermann, M., Schultz-Nielsen, K., Dumas, G., & Konvalinka, I. (2024).
+Arbitrary methodological decisions skew inter-brain synchronization
+estimates in hyperscanning-EEG studies. *Imaging Neuroscience*, 2.
 https://doi.org/10.1162/imag_a_00350
 
 Credits
@@ -211,9 +222,24 @@ def _compute_numba(self, complex_signal: np.ndarray, n_samp: int,
 
         return con
 
-    # Memory threshold for vectorized denominator (bytes). If the 5D tensor
-    # (E, F, C, C, T) would exceed this, fall back to the loop-based approach.
     _VRAM_THRESHOLD = 2 * 1024**3  # 2 GB
+    """
+    Memory threshold (bytes) for the vectorised torch denominator path.
+
+    Notes
+    -----
+    The torch implementation prefers a fully-vectorised broadcast over the
+    intermediate 5-D tensor of shape ``(n_epochs, n_freq, n_channels,
+    n_channels, n_samples)`` for the per-pair phase centering. When the
+    estimated tensor size in bytes exceeds this threshold, ``_compute_torch``
+    falls back to a per-pair loop on the same device (CPU / MPS / CUDA).
+
+    The 2 GB default is sized to keep one such tensor comfortably under
+    Apple-Silicon MPS and Quadro-class GPU memory budgets when the rest of
+    the pipeline (data tensors, kernel state) is already resident — a 4 GB
+    threshold can OOM on high-channel-count realistic_hd benchmarks. This
+    value is empirical; re-derive if you change the upstream tensor layout.
+    """
 
     def _compute_torch(self, complex_signal: np.ndarray, n_samp: int,
                        transpose_axes: tuple) -> np.ndarray:
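
The ``_VRAM_THRESHOLD`` docstring describes a size check on the intermediate ``(n_epochs, n_freq, n_channels, n_channels, n_samples)`` tensor. The arithmetic can be sketched as below. The helper names and the 4-byte element size (float32) are assumptions for illustration; how ``_compute_torch`` actually estimates the size is not shown in this diff.

```python
VRAM_THRESHOLD = 2 * 1024**3  # 2 GB, the value shown in the diff


def estimate_5d_bytes(n_epochs, n_freq, n_channels, n_samples, itemsize=4):
    """Bytes for an (E, F, C, C, T) tensor; itemsize=4 assumes float32."""
    return n_epochs * n_freq * n_channels * n_channels * n_samples * itemsize


def choose_path(n_epochs, n_freq, n_channels, n_samples, itemsize=4):
    """Pick the vectorised path unless the estimate exceeds the threshold."""
    size = estimate_5d_bytes(n_epochs, n_freq, n_channels, n_samples, itemsize)
    return "vectorised" if size <= VRAM_THRESHOLD else "per-pair loop"


# 10 epochs x 5 freqs x 32 channels x 1000 samples: about 0.2 GB, fits.
small = choose_path(10, 5, 32, 1000)
# 100 epochs x 5 freqs x 128 channels x 1000 samples: about 32.8 GB, does not.
large = choose_path(100, 5, 128, 1000)
```

Because the size grows with the square of the channel count, going from 32 to 128 channels alone multiplies the tensor by 16, which is why high-density montages hit the fallback first.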

hypyp/sync/base.py

Lines changed: 38 additions & 7 deletions
@@ -2,13 +2,44 @@
 # coding=utf-8
 
 """
-Base classes and helper functions for connectivity metrics.
-
-| Option | Description |
-| ------ | ----------- |
-| title | base.py |
-| authors | HyPyP Team |
-| date | 2026-01-30 |
+Base classes, einsum helpers, optional-dependency probing, and the
+AUTO_PRIORITY benchmark dispatch table for the connectivity metrics.
+
+This module is shared by every concrete metric in ``hypyp.sync``. It
+exposes:
+
+- ``BaseMetric`` — abstract base. Concrete metrics override
+  ``BaseMetric.compute`` and rely on the shared backend-resolution and
+  warning-fallback logic.
+- ``multiply_conjugate``, ``multiply_conjugate_time``,
+  ``multiply_product`` — vectorised einsum kernels (numpy).
+- ``multiply_conjugate_torch``, ``multiply_conjugate_time_torch`` —
+  torch equivalents (only resolvable if torch is installed).
+- ``AUTO_PRIORITY`` — benchmark-driven backend lookup table per
+  ``{metric_name: {platform: [gpu_backend, fallback]}}``.
+- Capability flags ``TORCH_AVAILABLE``, ``MPS_AVAILABLE``,
+  ``CUDA_AVAILABLE``, ``NUMBA_AVAILABLE``, ``METAL_AVAILABLE``,
+  ``CUPY_AVAILABLE`` — probed at import time so concrete metric classes
+  don't have to retry.
+
+Design note
+-----------
+``AUTO_PRIORITY`` is intentionally kept as a Python dict (not
+externalised to a YAML file) for three reasons:
+
+1. The values are not user-tunable — they are derived from benchmarks
+   on Mac M4 Max (131 rows) and Narval A100 (111 rows) and need
+   re-derivation if the kernels change. Putting them in YAML would
+   wrongly suggest they are configuration knobs.
+2. The per-call ``priority=`` kwarg on ``get_metric`` already provides
+   the override path users actually need.
+3. The table is short (9 entries) and sits next to its rationale
+   comment block; a YAML file would split the explanation from the
+   data.
+
+If a future benchmark sweep changes the optimal backend, the change
+should be a code edit (with a tests/benchmarks update) — not a config
+change.
 """
 
 import warnings
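
The docstring above mentions import-time capability flags and an ``AUTO_PRIORITY`` table shaped as ``{metric_name: {platform: [gpu_backend, fallback]}}``. One way such probing and lookup could fit together is sketched here. The table entries, platform keys, backend names, and function names are all assumptions for illustration; the real table contents are not part of this diff.

```python
import importlib.util


def module_available(name: str) -> bool:
    """True if a top-level module can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


# Hypothetical entries; only the nesting follows the documented shape.
AUTO_PRIORITY_SKETCH = {
    "plv": {"darwin": ["metal", "numpy"], "linux": ["cuda", "numpy"]},
}


def resolve_backend(metric: str, platform: str, available: set) -> str:
    """Return the first listed backend that is available, else 'numpy'."""
    for backend in AUTO_PRIORITY_SKETCH[metric][platform]:
        if backend in available:
            return backend
    return "numpy"
```

Probing once and passing the result in as a set keeps the lookup itself a pure function, which makes the priority order easy to test without any GPU present.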

hypyp/sync/ccorr.py

Lines changed: 9 additions & 0 deletions
@@ -3,6 +3,15 @@
 
 """
 Circular Correlation (CCorr) connectivity metric.
+
+CCorr is the circular analogue of Pearson's r: it measures the linear
+correlation between the sines of phase deviations from a global
+circular mean. See the ``CCorr`` class for the public API.
+
+References
+----------
+Fisher, N. I. (1995). *Statistical Analysis of Circular Data*.
+Cambridge University Press.
 """
 
 import numpy as np
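
The definition in this docstring (correlation between the sines of phase deviations from the circular mean) can be written out for a single pair of 1-D phase series. This is a sketch of the textbook formula only; the function names are hypothetical, and the ``CCorr`` class operates on multi-dimensional arrays and may differ in detail.

```python
import numpy as np


def circular_mean(angles: np.ndarray) -> float:
    """Direction of the mean resultant vector, in radians."""
    return float(np.angle(np.mean(np.exp(1j * angles))))


def circular_correlation(alpha: np.ndarray, beta: np.ndarray) -> float:
    """Correlation of the sines of deviations from each circular mean."""
    a = np.sin(alpha - circular_mean(alpha))
    b = np.sin(beta - circular_mean(beta))
    return float(np.sum(a * b) / np.sqrt(np.sum(a**2) * np.sum(b**2)))
```

Identical phase series give a value of 1, and negating one series flips the sign, mirroring how Pearson's r behaves for linear data.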

hypyp/sync/coh.py

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,16 @@
 
 """
 Coherence (Coh) connectivity metric.
+
+Coh measures the linear relationship between two complex analytic
+signals in the frequency domain — it is the squared modulus of the
+cross-spectrum normalised by the product of the auto-spectra. See the
+``Coh`` class for the public API.
+
+References
+----------
+Nunez, P. L., & Srinivasan, R. (2006). *Electric Fields of the Brain:
+The Neurophysics of EEG* (2nd ed.). Oxford University Press.
 """
 
 import numpy as np
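
For one pair of complex analytic signals, the quantity described in this docstring (squared modulus of the cross-spectrum over the product of the auto-spectra) can be sketched as follows. The function name is hypothetical and the spectra are estimated here by a plain average over samples, which is an assumption; the ``Coh`` class may average differently.

```python
import numpy as np


def magnitude_squared_coherence(x: np.ndarray, y: np.ndarray) -> float:
    """|S_xy|^2 / (S_xx * S_yy) for two complex analytic signals."""
    s_xy = np.mean(x * np.conj(y))   # cross-spectrum estimate
    s_xx = np.mean(np.abs(x) ** 2)   # auto-spectrum of x
    s_yy = np.mean(np.abs(y) ** 2)   # auto-spectrum of y
    return float(np.abs(s_xy) ** 2 / (s_xx * s_yy))
```

By the Cauchy-Schwarz inequality the result lies in [0, 1], reaching 1 when one signal is a constant complex multiple of the other.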

hypyp/sync/envelope_corr.py

Lines changed: 10 additions & 0 deletions
@@ -3,6 +3,16 @@
 
 """
 Envelope Correlation (EnvCorr) connectivity metric.
+
+EnvCorr is the Pearson correlation between the analytic-signal
+amplitude (envelope) of two channels — it captures slow co-modulation
+of band-limited power. See the ``EnvCorr`` class for the public API.
+
+References
+----------
+Hipp, J. F., Hawellek, D. J., Corbetta, M., Siegel, M., & Engel, A. K.
+(2012). Large-scale cortical correlation structure of spontaneous
+oscillatory activity. *Nature Neuroscience*, 15(6), 884-890.
 """
 
 import numpy as np
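
The envelope correlation described here reduces, for one channel pair, to a Pearson correlation of the magnitudes of the analytic signals. A minimal sketch, assuming the analytic signals are already available as complex arrays (the function name is hypothetical, and no orthogonalisation step is applied):

```python
import numpy as np


def envelope_correlation(x: np.ndarray, y: np.ndarray) -> float:
    """Pearson correlation between the envelopes |x| and |y|."""
    return float(np.corrcoef(np.abs(x), np.abs(y))[0, 1])
```

Since only magnitudes enter, a constant phase shift between the two channels leaves the result unchanged.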

hypyp/sync/imaginary_coh.py

Lines changed: 11 additions & 0 deletions
@@ -3,6 +3,17 @@
 
 """
 Imaginary Coherence (ImCoh) connectivity metric.
+
+ImCoh isolates the imaginary part of the cross-spectrum normalised by
+the auto-spectra, which makes it insensitive to zero-lag interactions
+(volume conduction). See the ``ImCoh`` class for the public API.
+
+References
+----------
+Nolte, G., Bai, O., Wheaton, L., Mari, Z., Vorbach, S., & Hallett, M.
+(2004). Identifying true brain interaction from EEG data using the
+imaginary part of coherency. *Clinical Neurophysiology*, 115(10),
+2292-2307.
 """
 
 import numpy as np
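
The zero-lag insensitivity claimed in this docstring can be checked numerically on one channel pair. The sketch below returns the signed imaginary part of coherency, following Nolte et al.; whether the ``ImCoh`` class takes an absolute value afterwards is not shown in this diff, and the function name is hypothetical.

```python
import numpy as np


def imaginary_coherency(x: np.ndarray, y: np.ndarray) -> float:
    """Im(S_xy) / sqrt(S_xx * S_yy) for two complex analytic signals."""
    s_xy = np.mean(x * np.conj(y))   # cross-spectrum estimate
    s_xx = np.mean(np.abs(x) ** 2)   # auto-spectrum of x
    s_yy = np.mean(np.abs(y) ** 2)   # auto-spectrum of y
    return float(np.imag(s_xy) / np.sqrt(s_xx * s_yy))
```

A scaled copy of the same signal (the volume-conduction case) has a purely real cross-spectrum and scores 0, while a quarter-cycle phase lag scores 1 in magnitude.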

hypyp/sync/kernels/_cuda_dispatch.py

Lines changed: 22 additions & 3 deletions
@@ -1,8 +1,13 @@
 """
 Shared CUDA dispatch logic for all pairwise sync metric kernels.
 
-Uses CuPy RawKernel for inline CUDA source. All kernels use float64
-for exact precision (A100 has 9.7 TFLOPS fp64).
+Provides ``run_pairwise_kernel``, the common launch + memory-management
+routine used by every metric-specific CUDA kernel module
+(``cuda_phase``, ``cuda_amplitude``, ``cuda_accorr``).
+
+Uses CuPy ``RawKernel`` for inline CUDA source. All kernels run in
+float64 — the NVIDIA A100 reference target has 9.7 TFLOPS of fp64
+throughput, so the precision/speed trade-off favours fp64 there.
 """
 
 import numpy as np
@@ -17,14 +22,28 @@ def run_pairwise_kernel(complex_signal, get_kernel_fn):
     """
     Shared dispatch for pairwise CUDA kernels.
 
+    Builds the upper-triangle channel-pair index list, transfers the real
+    and imaginary parts to the device as float64, launches one thread per
+    ``(epoch*freq, pair)`` tuple, and reads the result back. Computing
+    pairwise — rather than materialising the full ``(E, F, C, C, T)``
+    cross-spectrum — keeps device memory bounded at high channel counts.
+
     Parameters
     ----------
     complex_signal : np.ndarray, shape (E, F, C, T)
-    get_kernel_fn : callable -> CuPy RawKernel
+        Complex analytic signals (epochs, freqs, channels, samples).
+    get_kernel_fn : callable -> cupy.RawKernel
+        Lazily compiles (and caches) the metric-specific CUDA kernel.
 
     Returns
     -------
     np.ndarray, shape (E, F, C, C), float64
+        Connectivity matrix per (epoch, freq).
+
+    Notes
+    -----
+    The CuPy default memory pool is explicitly freed before returning, so
+    repeated calls in a tight loop don't accumulate device allocations.
     """
     kernel = get_kernel_fn()
 

hypyp/sync/kernels/cuda_accorr.py

Lines changed: 5 additions & 3 deletions
@@ -1,9 +1,11 @@
 """
 CUDA kernel for ACCorr (Adjusted Circular Correlation).
-Float64 for exact precision on NVIDIA GPUs.
 
-ACCorr requires a custom dispatch (not run_pairwise_kernel) because
-it needs an extra angle buffer for the sin^2 denominator in pass 2.
+ACCorr requires a **custom** dispatch rather than the shared
+``_cuda_dispatch.run_pairwise_kernel`` because it needs an extra angle
+buffer for the ``sin²`` adjusted-phase denominator in pass 2. Float64
+throughout for exact precision on NVIDIA GPUs (the A100 reference
+target has 9.7 TFLOPS of fp64).
 """
 
 import numpy as np

hypyp/sync/kernels/cuda_amplitude.py

Lines changed: 12 additions & 2 deletions
@@ -1,6 +1,16 @@
 """
-CUDA kernels for amplitude-based sync metrics: Coh, ImCoh, EnvCorr, PowCorr.
-All float64 for exact precision on NVIDIA GPUs.
+CUDA kernels for amplitude-based sync metrics.
+
+Implements:
+
+- ``coh_cuda`` — magnitude-squared Coherence.
+- ``imcoh_cuda`` — Imaginary Coherence.
+- ``envcorr_cuda`` — Envelope Correlation.
+- ``powcorr_cuda`` — Power Correlation.
+
+All kernels run in float64 for exact precision on NVIDIA GPUs. They
+share the pair-iteration scaffolding from
+``_cuda_dispatch.run_pairwise_kernel``.
 """
 
 import numpy as np
