Skip to content

Commit 66882ec

Browse files
committed
first draft script
1 parent ddd132b commit 66882ec

1 file changed

Lines changed: 255 additions & 0 deletions

File tree

src/pipelines/script_eyeflow.py

Lines changed: 255 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,255 @@
1+
from __future__ import annotations
2+
3+
import json
4+
5+
import h5py
6+
import numpy as np
7+
8+
from .core.base import ProcessPipeline, ProcessResult, registerPipeline, with_attrs
9+
10+
11+
@registerPipeline(name="script_eyeflow")
12+
class ScriptEyeFlow(ProcessPipeline):
13+
description = (
14+
"Generate artery and vein per-beat velocity datasets from EyeFlow "
15+
"segmentation masks and branch signals."
16+
)
17+
18+
HD_PARAMETERS = "HD_parameters"
19+
BRANCH_SIGNALS = "segmentation/Retina/binary/branch_signals"
20+
LABELED_VESSELS = "segmentation/Retina/binary/labeled_vessels"
21+
ARTERY_MASK = "segmentation/Retina/av/artery_mask"
22+
VEIN_MASK = "segmentation/Retina/av/vein_mask"
23+
HARMONIC_COUNT = 13
24+
ARTERY_VPB = "Artery/VelocityPerBeat"
25+
VEIN_VPB = "Vein/VelocityPerBeat"
26+
27+
@staticmethod
28+
def _pick_labels(
29+
h5file: h5py.File, vessel_kind: str, branch_count: int
30+
) -> list[int]:
31+
labels = np.asarray(h5file[ScriptEyeFlow.LABELED_VESSELS])
32+
artery = np.asarray(h5file[ScriptEyeFlow.ARTERY_MASK]) > 0
33+
vein = np.asarray(h5file[ScriptEyeFlow.VEIN_MASK]) > 0
34+
primary = artery if vessel_kind == "artery" else vein
35+
secondary = vein if vessel_kind == "artery" else artery
36+
picked: list[int] = []
37+
38+
for label in (int(value) for value in np.unique(labels) if value > 0):
39+
if label > branch_count:
40+
continue
41+
label_mask = labels == label
42+
if np.count_nonzero(label_mask & primary) > np.count_nonzero(
43+
label_mask & secondary
44+
):
45+
picked.append(label)
46+
47+
return picked
48+
49+
@staticmethod
50+
def _moving_average(values: np.ndarray, dt: float) -> np.ndarray:
51+
width = max(3, int(round(0.05 / dt)))
52+
if width % 2 == 0:
53+
width += 1
54+
kernel = np.ones(width, dtype=float) / width
55+
return np.convolve(values, kernel, mode="same")
56+
57+
@staticmethod
58+
def _find_peaks(values: np.ndarray, min_distance: int) -> np.ndarray:
59+
threshold = np.percentile(values, 95)
60+
peaks = (
61+
np.flatnonzero((values[1:-1] > values[:-2]) & (values[1:-1] >= values[2:]))
62+
+ 1
63+
)
64+
peaks = peaks[values[peaks] >= threshold]
65+
66+
kept: list[int] = []
67+
for peak in peaks:
68+
if not kept or peak - kept[-1] >= min_distance:
69+
kept.append(int(peak))
70+
71+
return np.asarray(kept, dtype=int)
72+
73+
@staticmethod
74+
def _load_hd_parameters(h5file: h5py.File) -> dict[str, object]:
75+
dataset = h5file.get(ScriptEyeFlow.HD_PARAMETERS)
76+
if dataset is None:
77+
return {}
78+
79+
payload = dataset[()]
80+
if isinstance(payload, np.ndarray) and payload.shape == ():
81+
payload = payload.item()
82+
if isinstance(payload, bytes):
83+
payload = payload.decode("utf-8")
84+
if not isinstance(payload, str):
85+
return {}
86+
87+
try:
88+
parsed = json.loads(payload)
89+
except json.JSONDecodeError:
90+
return {}
91+
return parsed if isinstance(parsed, dict) else {}
92+
93+
@classmethod
94+
def _infer_dt(cls, h5file: h5py.File) -> float:
95+
params = cls._load_hd_parameters(h5file)
96+
sampling_freq = float(
97+
h5file.attrs.get(
98+
"sampling_freq",
99+
h5file.attrs.get("fs", params.get("sampling_freq", params.get("fs", 0))),
100+
)
101+
)
102+
batch_step = float(
103+
h5file.attrs.get(
104+
"batch_size",
105+
h5file.attrs.get(
106+
"batch_stride",
107+
params.get("batch_size", params.get("batch_stride", 0)),
108+
),
109+
)
110+
)
111+
if sampling_freq <= 0 or batch_step <= 0:
112+
raise ValueError(
113+
"Could not infer dt. Expected sampling_freq and batch_size or "
114+
"batch_stride in root attrs or HD_parameters."
115+
)
116+
return batch_step / sampling_freq
117+
118+
@classmethod
119+
def _select_signal(cls, h5file: h5py.File, vessel: str) -> np.ndarray:
120+
branch_signals = np.asarray(h5file[cls.BRANCH_SIGNALS], dtype=float)
121+
labels = cls._pick_labels(h5file, vessel, branch_signals.shape[0])
122+
123+
if labels:
124+
signal = np.nanmean(branch_signals[np.asarray(labels) - 1], axis=0)
125+
else:
126+
signal = np.nanmean(branch_signals, axis=0)
127+
128+
return np.nan_to_num(signal - np.nanmean(signal))
129+
130+
@classmethod
131+
def _detect_systolic_peaks(cls, signal: np.ndarray, dt: float) -> np.ndarray:
132+
derivative = np.gradient(cls._moving_average(signal, dt))
133+
peaks = cls._find_peaks(derivative, max(1, int(0.5 / dt)))
134+
135+
if peaks.size < 2:
136+
raise ValueError("Could not detect at least two systolic peaks.")
137+
138+
return peaks
139+
140+
@staticmethod
141+
def _interp_cycle(beat: np.ndarray, n_fft: int) -> np.ndarray:
142+
x = np.arange(beat.size, dtype=float)
143+
xp = np.linspace(0.0, float(beat.size), n_fft + 1, endpoint=True)[:-1]
144+
return np.interp(xp, x, beat, period=float(beat.size))
145+
146+
@classmethod
147+
def _per_beat_signal_analysis(
148+
cls, signal: np.ndarray, sys_idx_list: np.ndarray
149+
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
150+
n_beats = sys_idx_list.size - 1
151+
n_fft = 1 << int(np.ceil(np.log2(np.max(np.diff(sys_idx_list)))))
152+
raw = np.full((n_beats, n_fft), np.nan, dtype=np.float32)
153+
fft_full = np.full((n_beats, n_fft), np.nan + 0j, dtype=np.complex64)
154+
band = np.full((n_beats, n_fft), np.nan, dtype=np.float32)
155+
156+
for beat_idx in range(n_beats):
157+
start = int(sys_idx_list[beat_idx])
158+
end = int(sys_idx_list[beat_idx + 1])
159+
beat = np.asarray(signal[start : end + 1], dtype=float)
160+
beat_interp = cls._interp_cycle(beat, n_fft)
161+
beat_fft = np.fft.fft(beat_interp, n_fft)
162+
163+
keep = min(max(1, cls.HARMONIC_COUNT), beat_fft.size)
164+
band_spectrum = np.zeros(n_fft, dtype=np.complex64)
165+
band_spectrum[:keep] = (2.0 * beat_fft[:keep]).astype(np.complex64)
166+
band_spectrum[0] = np.complex64(beat_fft[0])
167+
168+
raw[beat_idx] = beat_interp.astype(np.float32)
169+
fft_full[beat_idx] = beat_fft.astype(np.complex64)
170+
band[beat_idx] = np.abs(np.fft.ifft(band_spectrum, n_fft)).astype(
171+
np.float32
172+
)
173+
174+
return raw, fft_full, band
175+
176+
@staticmethod
177+
def _dataset_key(prefix: str, name: str) -> str:
178+
return f"{prefix}/{name}/value"
179+
180+
@classmethod
181+
def _velocity_per_beat_metrics(
182+
cls,
183+
prefix: str,
184+
signal: np.ndarray,
185+
sys_idx_list: np.ndarray,
186+
dt: float,
187+
) -> dict[str, object]:
188+
raw, fft_full, band = cls._per_beat_signal_analysis(signal, sys_idx_list)
189+
return {
190+
cls._dataset_key(prefix, "VelocitySignalPerBeat"): with_attrs(
191+
raw, {"unit": ["a.u."]}
192+
),
193+
cls._dataset_key(prefix, "VelocitySignalPerBeatFFT_abs"): with_attrs(
194+
np.abs(fft_full).astype(np.float32), {"unit": ["a.u."]}
195+
),
196+
cls._dataset_key(prefix, "VelocitySignalPerBeatFFT_arg"): with_attrs(
197+
np.angle(fft_full).astype(np.float32), {"unit": ["rad"]}
198+
),
199+
cls._dataset_key(prefix, "VelocitySignalPerBeatBandLimited"): with_attrs(
200+
band, {"unit": ["a.u."]}
201+
),
202+
cls._dataset_key(prefix, "VmaxPerBeatBandLimited"): with_attrs(
203+
np.max(band, axis=1).astype(np.float32), {"unit": ["a.u."]}
204+
),
205+
cls._dataset_key(prefix, "VminPerBeatBandLimited"): with_attrs(
206+
np.min(band, axis=1).astype(np.float32), {"unit": ["a.u."]}
207+
),
208+
cls._dataset_key(prefix, "VTIPerBeat"): with_attrs(
209+
(np.sum(raw, axis=1) * dt).astype(np.float32), {"unit": ["a.u.*s"]}
210+
),
211+
}
212+
213+
def run(self, h5file: h5py.File) -> ProcessResult:
214+
dt = self._infer_dt(h5file)
215+
if dt <= 0:
216+
raise ValueError("dt must be > 0.")
217+
218+
artery_signal = self._select_signal(h5file, "artery")
219+
vein_signal = self._select_signal(h5file, "vein")
220+
sys_idx_list = self._detect_systolic_peaks(artery_signal, dt)
221+
beat_count = int(sys_idx_list.size - 1)
222+
beat_period_idx = np.diff(sys_idx_list).astype(np.int32)[np.newaxis, :]
223+
beat_period_seconds = beat_period_idx.astype(np.float32) * np.float32(dt)
224+
225+
metrics: dict[str, object] = {
226+
self._dataset_key(self.ARTERY_VPB, "beatPeriodIdx"): with_attrs(
227+
beat_period_idx, {"unit": ["frames"]}
228+
),
229+
self._dataset_key(self.ARTERY_VPB, "beatPeriodSeconds"): with_attrs(
230+
beat_period_seconds, {"unit": ["s"]}
231+
),
232+
}
233+
metrics.update(
234+
self._velocity_per_beat_metrics(
235+
self.ARTERY_VPB,
236+
artery_signal,
237+
sys_idx_list,
238+
dt,
239+
)
240+
)
241+
metrics.update(
242+
self._velocity_per_beat_metrics(
243+
self.VEIN_VPB,
244+
vein_signal,
245+
sys_idx_list,
246+
dt,
247+
)
248+
)
249+
250+
attrs = {
251+
"dt_seconds": float(dt),
252+
"beat_count": beat_count,
253+
"harmonic_count": int(self.HARMONIC_COUNT),
254+
}
255+
return ProcessResult(metrics=metrics, attrs=attrs)

0 commit comments

Comments
 (0)