Commit 2176a3c

Gereon Elvers committed
Armeni / Schoffelen: chunked preprocessing for Colab-class memory budgets
The previous _preprocess_raw_to_h5 called raw.load_data() on the full recording before running notch + bp + ds. For Armeni sessions (~8 GB CTF, ~9 GB float64 preloaded) that OOM-killed free Colab runtimes before any pipeline step ran. Now we:

- open the recording lazily and drop reference / EEG / EOG / STIM channels (mne pick("meg")) before any data lands in RAM;
- iterate the recording in chunk_seconds (default 120 s) windows;
- preload and run the full pipeline (notch + bp + ds) inside each chunk, so each filter pass sees a chunk-sized signal;
- after the chunk passes through ds, downcast to float32 — filters are done and the H5 serializer writes float32 anyway, so the cast is lossless relative to the on-disk format and halves the accumulated chunk memory;
- merge a tiny last-chunk remainder into the previous chunk, so mne is never asked to filter a sub-second tail (this was triggering filter_length > signal warnings on 305 s recordings).

Verified end-to-end on the VM:

- Armeni sub-001/ses-001/task-compr (8 GB CTF): fresh build 430 s, peak RSS 4.81 GB (was 10.08 GB before chunking; OOM in free Colab). 29,460 phoneme samples, 298 MEG channels.
- Schoffelen sub-A2002/task-rest (530 MB CTF): fresh build 86 s, peak RSS 2.28 GB. 31 trial samples, 301 channels.

Both fit comfortably under 12 GB now.
1 parent 0941304 commit 2176a3c
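For scale, a back-of-the-envelope sketch of the working sets involved (the 1200 Hz CTF source rate and the ~1 h session length are assumptions for illustration; the 298 channels, float64 preload, 250 Hz target, and float32 downcast come from the numbers above):

```python
# Rough working-set sizes for the old full-preload path vs. one chunk.
# SRC_HZ and SESSION_S are assumed values, not taken from the commit.
CHANNELS, SRC_HZ, TARGET_HZ, SESSION_S, CHUNK_S = 298, 1200, 250, 3600, 120

def gib(nbytes: float) -> float:
    return nbytes / 2**30

full_preload = CHANNELS * SRC_HZ * SESSION_S * 8   # float64, whole recording
chunk_input = CHANNELS * SRC_HZ * CHUNK_S * 8      # float64, one 120 s window
chunk_kept = CHANNELS * TARGET_HZ * CHUNK_S * 4    # float32, after ds + downcast

print(f"full preload: {gib(full_preload):.1f} GiB")       # ~9.6 GiB, OOM on free Colab
print(f"chunk input:  {gib(chunk_input):.2f} GiB")        # ~0.32 GiB, transient
print(f"chunk kept:   {gib(chunk_kept) * 1024:.0f} MiB")  # ~34 MiB accumulated per chunk
```

Under those assumptions only one chunk-sized float64 window is resident at a time, and what accumulates over the session is the much smaller downsampled float32 output, consistent with the peak-RSS drops reported above.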

2 files changed: 173 additions & 37 deletions

pnpl/datasets/armeni2022/dataset.py

Lines changed: 104 additions & 19 deletions
@@ -387,49 +387,112 @@ def _serialise_fif_to_h5(self, fif_path: str, output_h5_path: str) -> None:
         os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
         fif_to_h5(raw, output_h5_path)
 
+    def _load_raw_lazy_meg_only(
+        self, subject: str, session: str, task: str, run: str,
+    ):
+        """Open the recording without preloading and pick MEG channels.
+
+        Reference / EEG / EOG / STIM channels are dropped here (the
+        events.tsv-driven tasks don't use them). The returned Raw is
+        not preloaded — caller iterates time chunks via ``crop`` +
+        ``load_data`` so peak memory stays bounded for multi-GB
+        recordings like Armeni's 8 GB sessions.
+        """
+        raw = self.load_raw_bids(subject, session, task, run, preload=False)
+        try:
+            raw.pick(picks="meg", exclude=[])
+        except Exception:
+            pass
+        return raw
+
     def _preprocess_raw_to_h5(
         self,
         subject: str,
         session: str,
         task: str,
         run: str,
         output_h5_path: str,
+        chunk_seconds: float = 120.0,
     ) -> None:
+        import gc
+
+        import mne
+        import numpy as np
+
         from ...preprocessing import Pipeline
         from ...preprocessing.config import (
             load_json_config,
             resolve_preprocessing_config,
         )
         from ...preprocessing.serialization import fif_to_h5
 
-        raw = self.load_raw_bids(subject, session, task, run, preload=True)
+        raw_lazy = self._load_raw_lazy_meg_only(subject, session, task, run)
+
+        if self.preprocessing is None:
+            # No filtering / downsampling — preload everything and
+            # serialize as-is. Caller is expected to size their machine
+            # to the recording.
+            raw_lazy.load_data(verbose=False)
+            os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
+            fif_to_h5(raw_lazy, output_h5_path)
+            return
+
+        step_names = self.preprocessing.split("+")
+        json_config = load_json_config(self.data_path)
+        resolved = resolve_preprocessing_config(
+            step_names=step_names,
+            json_config=json_config,
+            dataset_config=self.preprocessing_config,
+        )
 
-        if self.preprocessing is not None:
-            step_names = self.preprocessing.split("+")
-            json_config = load_json_config(self.data_path)
-            resolved = resolve_preprocessing_config(
-                step_names=step_names,
-                json_config=json_config,
-                dataset_config=self.preprocessing_config,
+        # Run the pipeline chunk-wise so the preload step never exceeds
+        # ``chunk_seconds`` worth of data. Each chunk emerges from the
+        # pipeline at the target sample rate (typically 250 Hz),
+        # dramatically smaller than the source. Concatenating the
+        # processed chunks rebuilds the full timeline for
+        # ``fif_to_h5``.
+        duration = float(raw_lazy.times[-1])
+        boundaries = _chunk_boundaries(duration, chunk_seconds)
+
+        processed_chunks: list = []
+        for start, end in boundaries:
+            chunk = raw_lazy.copy().crop(tmin=start, tmax=end)
+            chunk.load_data(verbose=False)
+            pipeline = Pipeline.from_string(
+                self.preprocessing, config=resolved.config
             )
-            pipeline = Pipeline.from_string(self.preprocessing, config=resolved.config)
-            raw = pipeline.run(
-                raw,
+            chunk = pipeline.run(
+                chunk,
                 subject=subject,
                 session=session,
                 task=task,
                 run=run,
                 bids_root=self.data_path,
                 verbose=False,
             )
-
-            fif_path = self.get_preprocessed_path(
-                subject, session, task, run,
-                preprocessing=self.preprocessing,
-                extension="fif",
-            )
-            os.makedirs(os.path.dirname(fif_path), exist_ok=True)
-            raw.save(fif_path, overwrite=True, verbose=False)
+            # Filters are done; downcast each chunk to float32 to halve
+            # the memory cost of holding all processed chunks before
+            # concatenation. fif_to_h5 saves float32 anyway, so this is
+            # lossless relative to the on-disk format.
+            if chunk._data is not None and chunk._data.dtype != np.float32:
+                chunk._data = chunk._data.astype(np.float32, copy=False)
+            processed_chunks.append(chunk)
+            gc.collect()
+
+        if len(processed_chunks) == 1:
+            raw = processed_chunks[0]
+        else:
+            raw = mne.concatenate_raws(processed_chunks)
+            del processed_chunks
+            gc.collect()
+
+        fif_path = self.get_preprocessed_path(
+            subject, session, task, run,
+            preprocessing=self.preprocessing,
+            extension="fif",
+        )
+        os.makedirs(os.path.dirname(fif_path), exist_ok=True)
+        raw.save(fif_path, overwrite=True, verbose=False)
 
         os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
         fif_to_h5(raw, output_h5_path)
@@ -491,6 +554,28 @@ def n_times(self) -> int:
         return self.points_per_sample
 
 
+def _chunk_boundaries(duration: float, chunk_seconds: float) -> List[tuple]:
+    """Return ``[(t0, t1), ...]`` covering ``[0, duration]`` such that
+    every interval is at least ``chunk_seconds * 0.5`` long. A short
+    remainder gets folded into the previous interval — otherwise the
+    last chunk can be too brief for mne's notch / bp filter design and
+    triggers ``filter_length is longer than the signal`` distortion."""
+    if duration <= 0:
+        return []
+    min_chunk = chunk_seconds * 0.5
+    out: list[tuple] = []
+    start = 0.0
+    while start < duration:
+        end = min(start + chunk_seconds, duration)
+        out.append((start, end))
+        start = end
+    if len(out) >= 2 and (out[-1][1] - out[-1][0]) < min_chunk:
+        last_end = out[-1][1]
+        out.pop()
+        out[-1] = (out[-1][0], last_end)
+    return out
+
+
 def _apply_component_filters(
     run_keys: List[tuple],
     *,
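To make the remainder merge concrete, here is what `_chunk_boundaries` returns at the default 120 s window. This is an illustrative call into a private helper; the expected outputs are traced by hand from the code above:

```python
from pnpl.datasets.armeni2022.dataset import _chunk_boundaries

# 290 s: the naive split leaves a 50 s tail, shorter than the 60 s
# minimum (chunk_seconds * 0.5), so it folds into the previous window.
print(_chunk_boundaries(290.0, 120.0))
# [(0.0, 120.0), (120.0, 290.0)]

# 305 s: the 65 s tail clears the 60 s minimum and stays its own chunk.
print(_chunk_boundaries(305.0, 120.0))
# [(0.0, 120.0), (120.0, 240.0), (240.0, 305.0)]
```

The merged window can grow up to 1.5 × chunk_seconds (here 180 s): a slightly larger final chunk costs a bit more transient memory, but keeps every chunk long relative to mne's designed filter length.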

pnpl/datasets/schoffelen2019/dataset.py

Lines changed: 69 additions & 18 deletions
@@ -353,49 +353,100 @@ def _serialise_fif_to_h5(self, fif_path: str, output_h5_path: str) -> None:
         os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
         fif_to_h5(raw, output_h5_path)
 
+    def _load_raw_lazy_meg_only(
+        self, subject: str, session: str, task: str, run: str,
+    ):
+        """Open the recording without preloading, pick MEG only.
+
+        Caller iterates time chunks via ``crop`` + ``load_data`` so
+        peak memory stays bounded for multi-GB recordings.
+        """
+        raw = self.load_raw_bids(subject, session, task, run, preload=False)
+        try:
+            raw.pick(picks="meg", exclude=[])
+        except Exception:
+            pass
+        return raw
+
     def _preprocess_raw_to_h5(
         self,
         subject: str,
         session: str,
         task: str,
         run: str,
         output_h5_path: str,
+        chunk_seconds: float = 120.0,
     ) -> None:
+        import gc
+
+        import mne
+        import numpy as np
+
         from ...preprocessing import Pipeline
         from ...preprocessing.config import (
             load_json_config,
             resolve_preprocessing_config,
         )
         from ...preprocessing.serialization import fif_to_h5
 
-        raw = self.load_raw_bids(subject, session, task, run, preload=True)
+        raw_lazy = self._load_raw_lazy_meg_only(subject, session, task, run)
 
-        if self.preprocessing is not None:
-            step_names = self.preprocessing.split("+")
-            json_config = load_json_config(self.data_path)
-            resolved = resolve_preprocessing_config(
-                step_names=step_names,
-                json_config=json_config,
-                dataset_config=self.preprocessing_config,
+        if self.preprocessing is None:
+            raw_lazy.load_data(verbose=False)
+            os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
+            fif_to_h5(raw_lazy, output_h5_path)
+            return
+
+        step_names = self.preprocessing.split("+")
+        json_config = load_json_config(self.data_path)
+        resolved = resolve_preprocessing_config(
+            step_names=step_names,
+            json_config=json_config,
+            dataset_config=self.preprocessing_config,
+        )
+
+        # Chunk-wise pipeline application — see Armeni2022 for the
+        # rationale. Each chunk is preloaded, processed (notch + bp +
+        # ds), and discarded except for the downsampled output.
+        from ..armeni2022.dataset import _chunk_boundaries
+        duration = float(raw_lazy.times[-1])
+        boundaries = _chunk_boundaries(duration, chunk_seconds)
+
+        processed_chunks: list = []
+        for start, end in boundaries:
+            chunk = raw_lazy.copy().crop(tmin=start, tmax=end)
+            chunk.load_data(verbose=False)
+            pipeline = Pipeline.from_string(
+                self.preprocessing, config=resolved.config
             )
-            pipeline = Pipeline.from_string(self.preprocessing, config=resolved.config)
-            raw = pipeline.run(
-                raw,
+            chunk = pipeline.run(
+                chunk,
                 subject=subject,
                 session=session,
                 task=task,
                 run=run,
                 bids_root=self.data_path,
                 verbose=False,
             )
+            if chunk._data is not None and chunk._data.dtype != np.float32:
+                chunk._data = chunk._data.astype(np.float32, copy=False)
+            processed_chunks.append(chunk)
+            gc.collect()
 
-            fif_path = self.get_preprocessed_path(
-                subject, session, task, run,
-                preprocessing=self.preprocessing,
-                extension="fif",
-            )
-            os.makedirs(os.path.dirname(fif_path), exist_ok=True)
-            raw.save(fif_path, overwrite=True, verbose=False)
+        if len(processed_chunks) == 1:
+            raw = processed_chunks[0]
+        else:
+            raw = mne.concatenate_raws(processed_chunks)
+            del processed_chunks
+            gc.collect()
+
+        fif_path = self.get_preprocessed_path(
+            subject, session, task, run,
+            preprocessing=self.preprocessing,
+            extension="fif",
+        )
+        os.makedirs(os.path.dirname(fif_path), exist_ok=True)
+        raw.save(fif_path, overwrite=True, verbose=False)
 
         os.makedirs(os.path.dirname(output_h5_path), exist_ok=True)
         fif_to_h5(raw, output_h5_path)
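Stripped of the pnpl-specific Pipeline plumbing, the pattern both files now share looks roughly like the sketch below. `preprocess_chunked` is a hypothetical name, and the notch / band-pass / resample parameters are placeholders rather than the datasets' configured values; only the chunking structure mirrors the diff:

```python
import gc

import mne
import numpy as np

def preprocess_chunked(raw_lazy: mne.io.BaseRaw,
                       chunk_seconds: float = 120.0) -> mne.io.BaseRaw:
    """Filter and downsample a non-preloaded Raw in bounded-memory chunks."""
    duration = float(raw_lazy.times[-1])
    processed = []
    start = 0.0
    # NB: unlike the real code, this sketch does not fold a short tail
    # into the previous window (_chunk_boundaries handles that above).
    while start < duration:
        end = min(start + chunk_seconds, duration)
        # Only this window is ever resident at the source sample rate.
        chunk = raw_lazy.copy().crop(tmin=start, tmax=end)
        chunk.load_data(verbose=False)
        chunk.notch_filter(freqs=50.0, verbose=False)           # placeholder freq
        chunk.filter(l_freq=0.5, h_freq=100.0, verbose=False)   # placeholder band
        chunk.resample(250, verbose=False)                      # shrink before keeping
        # Mirrors the commit's use of the private buffer to downcast in place.
        chunk._data = chunk._data.astype(np.float32, copy=False)
        processed.append(chunk)
        gc.collect()
        start = end
    if len(processed) == 1:
        return processed[0]
    return mne.concatenate_raws(processed)
```

Each chunk is filtered independently with no overlap between windows, so the commit trades exact filter continuity at chunk boundaries for bounded peak RSS; the remainder merge exists precisely so no window is ever too short for the designed filter.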
