From 0ca0479a6e82490b16f61425639ffea946b5b80d Mon Sep 17 00:00:00 2001 From: Marija Zelic <145926170+masazelic@users.noreply.github.com> Date: Mon, 27 Apr 2026 09:43:59 +0200 Subject: [PATCH 1/3] Add fixed PanLUNA to repos cleared PanLUNA --- PanLUNA.md | 148 ++++ PanLUNA.py | 311 ++++++++ PanLUNA_finetune.yaml | 30 + PanLUNA_pretrain.yaml | 30 + README.md | 49 +- dataset_types.yaml | 33 + finetune_data_module_multimodal_PanLUNA.yaml | 109 +++ finetune_data_module_unimodal_PanLUNA.yaml | 49 ++ finetune_task_PanLUNA.py | 419 +++++++++++ finetune_task_PanLUNA.yaml | 22 + finetuning_multimodal_datasets_PanLUNA.py | 126 ++++ finetuning_unimodal_datasets_PanLUNA.py | 97 +++ lead_positions.py | 212 ++++++ make_code15_dataset.py | 146 ++++ make_cpsc2018_dataset.py | 147 ++++ make_csn_dataset.py | 157 ++++ make_hmc_dataset.py | 187 +++++ make_mimic_iv_dataset.py | 127 ++++ make_ptbxl_dataset.py | 158 ++++ make_pulsedb_dataset.py | 152 ++++ make_seed_vii_dataset.py | 202 ++++++ make_siena_dataset.py | 124 ++++ make_wesad_dataset.py | 274 +++++++ multiloader_data_module_PanLUNA.yaml | 720 +++++++++++++++++++ pretrain_task_PanLUNA.py | 273 +++++++ pretrain_task_PanLUNA.yaml | 22 + pretraining_datasets_PanLUNA.py | 96 +++ process_raw_ecg.py | 121 ++++ 28 files changed, 4540 insertions(+), 1 deletion(-) create mode 100644 PanLUNA.md create mode 100644 PanLUNA.py create mode 100644 PanLUNA_finetune.yaml create mode 100644 PanLUNA_pretrain.yaml create mode 100644 dataset_types.yaml create mode 100644 finetune_data_module_multimodal_PanLUNA.yaml create mode 100644 finetune_data_module_unimodal_PanLUNA.yaml create mode 100644 finetune_task_PanLUNA.py create mode 100644 finetune_task_PanLUNA.yaml create mode 100644 finetuning_multimodal_datasets_PanLUNA.py create mode 100644 finetuning_unimodal_datasets_PanLUNA.py create mode 100644 lead_positions.py create mode 100644 make_code15_dataset.py create mode 100644 make_cpsc2018_dataset.py create mode 100644 make_csn_dataset.py create 
mode 100644 make_hmc_dataset.py create mode 100644 make_mimic_iv_dataset.py create mode 100644 make_ptbxl_dataset.py create mode 100644 make_pulsedb_dataset.py create mode 100644 make_seed_vii_dataset.py create mode 100644 make_siena_dataset.py create mode 100644 make_wesad_dataset.py create mode 100644 multiloader_data_module_PanLUNA.yaml create mode 100644 pretrain_task_PanLUNA.py create mode 100644 pretrain_task_PanLUNA.yaml create mode 100644 pretraining_datasets_PanLUNA.py create mode 100644 process_raw_ecg.py diff --git a/PanLUNA.md b/PanLUNA.md new file mode 100644 index 0000000..312f07d --- /dev/null +++ b/PanLUNA.md @@ -0,0 +1,148 @@ +## PanLUNA + +PanLUNA is a self-supervised **pan-modal** biosignal foundation model that jointly processes **EEG, ECG, and PPG** within a single shared encoder. Extending LUNA's channel-unification module, PanLUNA treats multimodal channels as entries in a **unified query set augmented with sensor-type embeddings**, enabling efficient cross-modal early fusion while remaining inherently **robust to missing modalities** at inference time. Despite its compact 5.4M-parameter footprint, PanLUNA matches or exceeds models up to 57× larger, and supports quantization-aware INT8 deployment on the GAP9 ultra-low-power RISC-V microcontroller for continuous wearable monitoring. + +--- + +### Default Input Assumptions + +All modalities are resampled to 256 Hz and segmented into non-overlapping 5-second windows with a patch size of 32 samples, unless a downstream task specifies otherwise (e.g., 10-second windows for ECG benchmarks, 30-second epochs for sleep staging). + +| Modality | Channels | Native Sampling Rate | +|----------|----------|----------------------| +| EEG | 20–22 (pre-training); 29 (Siena) | 250–512 Hz | +| ECG | 12 (pre-training and cardiac benchmarks); 1 (HMC sleep staging) | 400–500 Hz | +| PPG | 1 (PulseDB) | 125 Hz | + +Missing modalities are handled natively at inference without any architectural modification. 
+ +--- + +### Preprocessing + +A standardized modality-specific preprocessing pipeline is applied to all data: + +1. **Filtering**: Bandpass filtering with a 4th-order Butterworth filter, with modality-specific cutoffs: EEG 0.1–75 Hz; ECG 0.5–120 Hz; PPG 0.5–8 Hz. A notch filter (50 Hz or 60 Hz) is additionally applied. +2. **Resampling**: All signals resampled to 256 Hz. +3. **Normalization**: Per-channel z-score normalization, to account for the large amplitude differences across modalities (e.g., EEG in µV vs. ECG in mV). +4. **Segmentation**: Non-overlapping 5-second windows during pre-training; task-specific windowing during fine-tuning (e.g., 10-second windows for PTB-XL/CSN, 30-second epochs for HMC sleep staging). + +--- + +### Architecture Overview + +PanLUNA extends LUNA to the multimodal setting by generalizing topology invariance to cross-modal fusion. + +1. **Input Representation** + Channels from all modalities are concatenated along the channel dimension before entering the model. **Sensor-type embeddings** are introduced via a modality-specific lookup table, added to channel features at the input stage to distinguish sensing modalities. Channel positional encodings are modality-specific: + - **EEG**: Normalized 3D electrode coordinates encoded with sinusoidal embeddings (as in LUNA). + - **ECG**: Lead-angle estimates derived from [anatomical measurements on 30 body scans](https://www.ijcai.org/proceedings/2021/0495.pdf), constructing a spatial encoding analogous to EEG electrode positioning. + - **PPG**: Neutral coordinate (0, 0) assigned; the model relies on the sensor-type embedding for modality identification. + +2. **Patch Feature Extraction** + Signals are partitioned into short temporal patches and embedded via lightweight convolutional encoders combined with frequency features from the real-valued FFT. Patch-level features are augmented with positional encodings and sensor-type embeddings before entering the unification module. + +3. 
**Channel–Modality Unification Module** + Cross-attention aggregates information across both channels and modalities through a shared set of latent queries. This design removes the requirement for paired multimodal recordings during pre-training and enables training on large-scale unimodal corpora. + +4. **Temporal Transformer Encoder** + The unified latent sequence is processed by a patch-wise temporal Transformer with **Rotary Positional Embeddings (RoPE)** to capture long-range temporal dependencies. Self-attention operates on the fixed-size latent representation, fully decoupled from electrode count and modality composition. + +5. **Decoding and Classifier Heads** + During pre-training, a reconstruction decoder attends to encoder outputs to recover masked signal patches in a channel-specific manner. During fine-tuning this decoder is discarded and replaced by a lightweight aggregation query that pools the encoder output into a single representation, fed to a classification head. Three adaptation strategies are supported: + - **Full Fine-tuning (FF)**: All 5.4M parameters updated. + - **Frozen Encoder (FE)**: Backbone fixed; only the classification head (~400k parameters) trained. + - **LoRA**: Low-rank matrices (rank 16, ~180k parameters, ~580k total) injected into selected Transformer layers. + +--- + +### Self-Supervised Learning (SSL) Objective + +PanLUNA is pre-trained with a **masked signal reconstruction** objective. A random subset of patch tokens is masked, and the reconstruction decoder is trained to recover the original signal patches in a channel-specific manner. + +--- + +### Classification Protocols + +- **BC – Binary Classification**: Window-level binary label (e.g., normal vs. abnormal EEG on TUAB). +- **MCC – Multi-class Classification**: Single-label classification per window (e.g., 5-stage sleep scoring on HMC). +- **Multi-label Classification**: Multiple co-occurring labels per window (e.g., 19-label PTB-XL-Form ECG morphology). 
+ +--- + +### Model Variants + +| Variant | Parameters | +|---------|------------| +| PanLUNA | 5.4M | + +--- + +### Training Setup + +- **Pre-training** + - **Datasets**: ~40,000 hours of heterogeneous biosignal data across five corpora: + + | Dataset | Modality | Subjects | Channels | FS (Hz) | Window | + |---------|----------|----------|----------|---------|--------| + | TUEG | EEG | 14,987 | 20/22 | 250 | 5 s | + | Siena | EEG | 14 | 29 | 512 | 5 s | + | MIMIC-IV | ECG | 161,352 | 12 | 500 | 5 s | + | CODE-15% | ECG | 233,700 | 12 | 400 | 5 s | + | PulseDB | ECG, PPG | 5,361 | 2 | 125 | 5 s | + + - **Objective**: Masked signal reconstruction; each modality can be used independently (no paired multimodal data required). + +- **Fine-tuning** + - Reconstruction decoder replaced with aggregation query + classification head; three adaptation strategies available (FF, FE, LoRA). + - **Loss**: Cross-Entropy for multi-class; BCE for multi-label classification. + - **Dataset splits**: + - **TUAB**: Official predefined train/val/test split. + - **PTB-XL (Super/Sub/Form/Rhythm)** and **CSN**: MERL ICML 2024 protocol. + - **HMC (sleep staging)**: Splits as in PhysioOmni. + +- **Quantization** + - Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT) via Brevitas; evaluated at INT8, INT4, and INT2 weights. QAT runs for 15 fine-tuning epochs and recovers ≥96% of FP32 performance at INT8; INT2 weights achieve up to 16× storage reduction with graceful degradation. + +--- + +### Edge Deployment (GAP9) + +PanLUNA is deployed on the **GAP9 ultra-low-power RISC-V microcontroller** (9-core cluster at 370 MHz, 1.5 MB L2 SRAM) using the BioFoundation edge framework with automated operator tiling, double-buffered DMA, NE16 acceleration, and custom tiled kernels for cross-attention projections and sensor-type embedding lookup. 
+ +| Configuration | Channels | Window | MACs | Latency | Energy | Power | +|---------------|----------|--------|------|---------|--------|-------| +| ECG only | 12 | 10 s | 120.5 M | 325.6 ms | 18.8 mJ | 60.2 mW | +| EEG + ECG | 5 | 30 s | 446.2 M | 1.206 s | 68.65 mJ | 56.9 mW | + +Streaming latency for ECG (patch-triggered): **450.6 ms** (125 ms acquisition + 325.6 ms compute). Estimated continuous monitoring battery life on a 300 mAh / 3.7 V wearable: **~24 days** (ECG-only), **~20 days** (multimodal sleep staging). This is, to our knowledge, the first deployment of a multimodal physiological FM on an ultra-low-power MCU. + +--- + +### Results Summary + +**TUAB (Abnormal EEG Detection)** +- PanLUNA (FF): **81.21%** balanced accuracy, 0.8999 AUC-PR, 0.8932 AUROC — outperforming LUNA-Base and LUNA-Large despite being 8–57× smaller. + +**HMC (Multimodal Sleep Staging, 5-class)** + +| Variant | Test Modality | Bal. Acc. | Cohen's κ | Weighted F1 | +|---------|--------------|-----------|-----------|-------------| +| PanLUNA (FF) | EEG | **0.7416** | **0.6946** | **0.7659** | +| PanLUNA (FF) | EEG + ECG | 0.7002 | 0.6561 | 0.7383 | +| PanLUNA (FF) | ECG only | 0.2977 | 0.1095 | 0.2876 | +| PanLUNA (QAT INT8) | EEG + ECG | 0.7347 | 0.6913 | 0.7273 | + +State-of-the-art on HMC; surpasses PhysioOmni by +1.27% balanced accuracy. + +**PTB-XL / CSN (Cardiac Benchmarks, LoRA FE, FP32)** + +| Task | AUROC | +|------|-------| +| PTB-XL Super | 0.9083 | +| PTB-XL Sub | 0.8880 | +| PTB-XL Form | 0.8331 | +| PTB-XL Rhythm | 0.9641 | +| CSN | 0.9505 | + +State-of-the-art on PTB-XL Super and CSN. QAT INT8 recovers ≥96% of FP32 AUROC across all tasks; INT2 weights achieve up to 16× storage reduction with graceful degradation. 
\ No newline at end of file diff --git a/PanLUNA.py b/PanLUNA.py new file mode 100644 index 0000000..2a16e3d --- /dev/null +++ b/PanLUNA.py @@ -0,0 +1,311 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* + +import math +from functools import partial +from typing import List +import numpy as np +import random +import torch +import torch.distributed as dist +import torch.nn as nn +import torch.nn.functional as F + +from timm.models.layers import trunc_normal_ as __call_trunc_normal_ +from timm.models.layers import Mlp + +from einops import rearrange +from models.modules.rope_transformer_encoder_block import RotaryTransformerBlock +from models.modules.frequency_embedder import FrequencyFeatureEmbedder +from models.modules.lead_positions import ChannelEmbeddings, SensorEmbeddings + +def trunc_normal_(tensor, mean=0., std=1.): + __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std) + +def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Tensor: + """ + coords: (N, C, 3) + Returns: (N, C, embed_size) + """ + N, C, dim = coords.shape + device = coords.device + freqs = embed_size // (2 * 
dim) + leftover = embed_size - freqs * 2 * dim + freq_bands = 2.0 ** torch.arange(freqs, device=device).float() + scaled_coords = coords.unsqueeze(-1) * freq_bands.view(1, 1, 1, -1) # (N, C, dim, freqs) + sin_enc = torch.sin(scaled_coords) # (N, C, dim, freqs) + cos_enc = torch.cos(scaled_coords) # (N, C, dim, freqs) + encoded = torch.stack([sin_enc, cos_enc], dim=-1).permute(0, 1, 3, 2, 4).reshape(N, C, freqs * dim * 2) + if leftover > 0: + pad = torch.zeros(N, C, leftover, device=device, dtype=coords.dtype) + encoded = torch.cat([encoded, pad], dim=-1) + return encoded + +class PatchReconstructionHeadWithQueries(nn.Module): + def __init__( + self, + input_dim: int = 8, + embed_dim: int = 768, + num_heads: int = 8, + num_queries: int = 4, + ): + super().__init__() + self.input_dim = input_dim + self.embed_dim = embed_dim + self.reconstruction_shape = self.input_dim + self.num_queries = num_queries + # Projection from embed space to pixel space, according to type of input + self.decoder_pred = nn.TransformerDecoder( + nn.TransformerDecoderLayer(embed_dim, num_heads, dropout=0.0, batch_first=True, activation='gelu', dim_feedforward=int(embed_dim*4), norm_first=True), + num_layers=1 + ) + self.norm = nn.LayerNorm(embed_dim) + self.decoder_linear = Mlp(embed_dim, int(embed_dim*4), input_dim, act_layer=nn.GELU, drop=0.0) #nn.Linear(embed_dim, input_dim, bias=True) + + def forward(self, enc, decoder_queries): + """ + enc: [B, num_patches, embed_dim], embed_dim = Q*D + decoder_queries: [B*num_patches, num_channels, embed_dim] + """ + + B, num_patches, embed_dim = enc.shape + enc = rearrange(enc, 'B t (Q D) -> (B t) Q D', Q=self.num_queries) + out = self.decoder_pred(decoder_queries, enc) # (B*t, C, D) + out = self.norm(out) + out = self.decoder_linear(out) # (B*t, C, patch_size) + out = rearrange(out, '(B t) C P -> B C (t P)', B=B) + return out + +class ClassificationHeadWithQueries(nn.Module): + def __init__( + self, + input_dim: int = 8, + embed_dim: int = 768, + 
num_queries: int = 8, + num_heads: int = 8, + num_classes: int = 2, + ): + super().__init__() + self.input_dim = input_dim + self.embed_dim = int(embed_dim*num_queries) + self.reconstruction_shape = self.input_dim + self.decoder_attn = nn.MultiheadAttention(self.embed_dim, num_heads, batch_first=True, dropout=0.15) + self.decoder_ffn = Mlp(in_features=self.embed_dim, hidden_features=int(2*self.embed_dim), out_features=num_classes, act_layer=nn.GELU, drop=0.15) + + self.learned_agg = nn.Parameter(torch.randn(1, 1, self.embed_dim), requires_grad=True) + + def forward(self, x): + """ + Output shape: + [B, num_tokens, in_chans, input_dim] + Args: + x: [B, num_tokens+1, embed_dim] + channel_embeddings: [B, in_chans, embed_dim] + """ + B, num_patches, embed_dim = x.shape + decoder_queries = self.learned_agg.repeat(x.shape[0], 1, 1) + + x = self.decoder_attn(query=decoder_queries, key=x, value=x)[0] + x = x[:,0,:] + x = self.decoder_ffn(x) + return x + +class CrossAttentionBlock(nn.Module): + def __init__(self, num_queries, input_embed_dim, output_embed_dim, num_heads, dropout_p=0.1, ff_dim=2048, pre_norm=True): + super(CrossAttentionBlock, self).__init__() + self.num_queries = num_queries + self.dropout_p = dropout_p + self.query_embed = nn.Parameter(torch.randn(1, num_queries, input_embed_dim), requires_grad=True) # Learnable queries + self.cross_attention = nn.MultiheadAttention(embed_dim=input_embed_dim, num_heads=num_heads, dropout=dropout_p,batch_first=True) + self.temparature = nn.Parameter(torch.tensor(1.0), requires_grad=False) + self.ffn = Mlp(input_embed_dim, ff_dim, output_embed_dim, act_layer=nn.GELU, drop=dropout_p, norm_layer=nn.LayerNorm) + self.keys_norm = nn.LayerNorm(input_embed_dim) + self.values_norm = nn.LayerNorm(input_embed_dim) + self.queries_norm = nn.LayerNorm(input_embed_dim) + self.query_self_attn = nn.TransformerEncoder(nn.TransformerEncoderLayer(input_embed_dim, nhead=num_heads, activation='gelu', dim_feedforward=ff_dim, batch_first=True, 
norm_first=True), num_layers=3) + + def initialize_weights(self): + torch.nn.init.orthogonal_(self.query_embed, gain=1.0) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + # we use xavier_uniform following official JAX ViT: + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def forward(self, x): + # x is the input with shape (batch_size*num_patches, num_channels, embed_dim) + batch_size, num_channels, _ = x.size() + queries = self.query_embed.repeat(batch_size,1,1) + queries = self.queries_norm(queries) + keys = self.keys_norm(x) + values = self.values_norm(x) + + attention_out, attention_scores = self.cross_attention(query=queries,key=keys,value=values) # Shape: (batch_size*num_patches, num_queries, embed_dim) + attention_out = self.ffn(attention_out) + attention_out + attention_out = self.query_self_attn(attention_out) + return attention_out, attention_scores # Shape: (batch_size*num_patches, num_queries, embed_dim) + +class PatchEmbedNetwork(nn.Module): + def __init__(self, embed_dim=64, patch_size=40): + super(PatchEmbedNetwork, self).__init__() + self.patch_size = patch_size + self.embed_dim = embed_dim + self.in_channels = 1 + self.out_channels = int(embed_dim//4) + self.groups = 4 + self.kernel_size = int(patch_size//2) + self.proj_in = nn.Sequential( + nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=(1, self.kernel_size-1), stride=(1, self.kernel_size//2), padding=(0, self.kernel_size//2-1)), + nn.GroupNorm(self.groups, self.out_channels), + nn.GELU(), + + nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(self.groups, self.out_channels), + nn.GELU(), + + nn.Conv2d(in_channels=self.out_channels, 
out_channels=self.out_channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)), + nn.GroupNorm(self.groups, self.out_channels), + nn.GELU(), + ) + def forward(self, x): + """ + x: (B, C, T) + output: (B, C*S, D) where S = T//patch_size, D = embed_dim + """ + x = rearrange(x, 'B C (S P) -> B (C S) P', P=self.patch_size) + x = x.unsqueeze(1) + x = self.proj_in(x) + x = rearrange(x, 'B E CS D -> B CS (D E)') + return x + +class PanLUNA(nn.Module): + """ + LUNA extension for multimodal processing. + """ + def __init__(self, patch_size=40, num_queries=4, + embed_dim=64, depth=8, num_heads=2, + mlp_ratio=4., norm_layer=nn.LayerNorm, + drop_path=0.0, num_classes=0): + super().__init__() + self.embed_dim = embed_dim + self.num_queries = num_queries + self.patch_size = patch_size + self.patch_embed_size = embed_dim + self.num_heads = num_heads + self.num_classes = num_classes + self.depth = depth + self.patch_embed = PatchEmbedNetwork(embed_dim=self.embed_dim, patch_size=patch_size) + self.freq_embed = FrequencyFeatureEmbedder(embed_dim=self.embed_dim, patch_size=patch_size) + self.sensor_embed = SensorEmbeddings(embed_dim=self.embed_dim) + self.channel_location_embedder = nn.Sequential( + Mlp(in_features=int(self.patch_embed_size), out_features=int(self.patch_embed_size), hidden_features=int(self.patch_embed_size*2), act_layer=nn.GELU, drop=0.0, norm_layer=nn.LayerNorm), + ) + self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + self.cross_attn = CrossAttentionBlock(num_queries=num_queries, input_embed_dim=self.embed_dim, output_embed_dim=self.embed_dim, num_heads=self.num_heads, ff_dim=int(mlp_ratio*self.embed_dim), pre_norm=True) + self.blocks = nn.ModuleList([ + RotaryTransformerBlock(dim=int(self.embed_dim*self.num_queries), num_heads=int(self.num_heads*self.num_queries), mlp_ratio=mlp_ratio, qkv_bias=True, drop=0.0, attn_drop=0.0, drop_path=drop_path, norm_layer=norm_layer) + for i in range(depth)]) + self.norm = 
norm_layer(int(self.embed_dim*self.num_queries)) + if num_classes==0: # reconstruction (pre-training) + self.decoder_head = PatchReconstructionHeadWithQueries(input_dim=patch_size, embed_dim=self.embed_dim, num_heads=self.num_heads, num_queries=num_queries) + self.channel_emb = ChannelEmbeddings(self.embed_dim) + else: # classification + self.classifier = ClassificationHeadWithQueries(input_dim=patch_size, num_queries=num_queries, embed_dim=self.embed_dim, num_classes=num_classes, num_heads=self.num_heads) + self.mask_token.requires_grad = False # no use of mask token for classification + self.initialize_weights() + + def initialize_weights(self): + self.cross_attn.initialize_weights() + trunc_normal_(self.mask_token, std=.02) + self.apply(self._init_weights) + self.fix_init_weight() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + torch.nn.init.xavier_normal_(m.weight) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + def fix_init_weight(self): + def rescale(param, layer_id): + param.div_(math.sqrt(2.0 * layer_id)) + + for layer_id, layer in enumerate(self.blocks): + rescale(layer.attn.proj.weight.data, layer_id + 1) + rescale(layer.mlp.fc2.weight.data, layer_id + 1) + + def prepare_tokens(self, x_signal, channel_locations, sensor_type, mask=None): + num_channels = channel_locations.shape[1] + num_patches_per_channel = x_signal.shape[-1] // self.patch_size + x_patched = self.patch_embed(x_signal) + freq_embed = self.freq_embed(x_signal) + x_patched = x_patched + freq_embed + x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel + if mask is not None: + mask_tokens = self.mask_token.repeat(x_masked.shape[0], x_masked.shape[1], 1) # (B, N, D) N = C * num_patches_per_channel + mask = rearrange(mask, 'B C (S P) -> B (C S) P', P=self.patch_size) # (B, C, T) -> (B, N, P) + mask = 
(mask.sum(dim=-1) > 0).unsqueeze(-1).float() # (B, N, 1), since a patch is either fully masked or not + x_masked = torch.where(mask.bool(), mask_tokens, x_masked) + channel_min = torch.min(channel_locations, dim=1, keepdim=True)[0] + channel_max = torch.max(channel_locations, dim=1, keepdim=True)[0] + channel_locations = (channel_locations - channel_min) / (channel_max - channel_min + 1e-8) + if mask is not None: + channel_locations = channel_locations + torch.randn_like(channel_locations) * 0.02 + channel_locations = nerf_positional_encoding(channel_locations, self.patch_embed_size) + channel_locations_emb = self.channel_location_embedder(channel_locations) + + # Added part - sensor type + sensor_type = self.sensor_embed(sensor_type).repeat_interleave(num_patches_per_channel, dim=0) + x_tokenized = rearrange(x_masked, 'B (C t) D -> (B t) C D', C=num_channels) + channel_locations_emb = channel_locations_emb.repeat_interleave(num_patches_per_channel, dim=0) + x_tokenized = x_tokenized + channel_locations_emb + sensor_type + + return x_tokenized, channel_locations_emb, sensor_type + + def forward(self, x_signal, mask, channel_locations, sensor_type, channel_names=None, **kwargs): + x_original = x_signal + B, C, T = x_signal.shape + x, channel_locations_emb, sensor_type = self.prepare_tokens(x_signal, channel_locations, sensor_type, mask=mask) + + x, attention_scores = self.cross_attn(x) # (B*num_patches, Q, D) + x = rearrange(x, '(B t) Q D -> B t (Q D)', B=B) # (B, num_patches, Q*D), Q*D is the new embed_dim + num_patches = x.shape[1] + for blk in self.blocks: + x = blk(x) # (B, N, D) + x_latent = self.norm(x) # (B, N, D) + + if self.num_classes > 0: + x_classified = self.classifier(x_latent) + return x_classified, x_original + else: + channel_emb = self.channel_emb(channel_names) + channel_emb = channel_emb.repeat(num_patches, 1, 1) + decoder_queries = channel_locations_emb + channel_emb + sensor_type # (B*N, C, D) + x_reconstructed = self.decoder_head(x_latent, 
decoder_queries) + return x_reconstructed, x_original, attention_scores diff --git a/PanLUNA_finetune.yaml b/PanLUNA_finetune.yaml new file mode 100644 index 0000000..bd3ea35 --- /dev/null +++ b/PanLUNA_finetune.yaml @@ -0,0 +1,30 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +model: + _target_: models.PanLUNA.PanLUNA + patch_size: 32 + embed_dim: 64 + num_heads: 2 + depth: 6 + num_queries: 4 + mlp_ratio: 4 + drop_path: 0.1 + num_classes: 2 \ No newline at end of file diff --git a/PanLUNA_pretrain.yaml b/PanLUNA_pretrain.yaml new file mode 100644 index 0000000..94d7311 --- /dev/null +++ b/PanLUNA_pretrain.yaml @@ -0,0 +1,30 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. 
* +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +model: + _target_: models.PanLUNA.PanLUNA + patch_size: 32 + embed_dim: 64 + num_heads: 2 + depth: 6 + num_queries: 4 + mlp_ratio: 4 + drop_path: 0.0 + num_classes: 0 \ No newline at end of file diff --git a/README.md b/README.md index 5fe540a..4d71334 100644 --- a/README.md +++ b/README.md @@ -13,6 +13,9 @@ LuMamba Paper + + PanLUNA Paper + Hugging Face: FEMBA @@ -25,6 +28,9 @@ Hugging Face: LuMamba + + Hugging Face: LuMamba + GitHub Stars @@ -32,7 +38,7 @@ Copyright (C) 2025-2026 ETH Zurich, Switzerland. SPDX-License-Identifier: Apache-2.0. See LICENSE file for details. -Authors: Thorir Mar Ingolfsson, Anna Tegon, Berkay Döner, Xiaying Wang, Matteo Fasulo, Danaé Broustail, Yawei Li & Luca Benini. +Authors: Thorir Mar Ingolfsson, Anna Tegon, Berkay Döner, Xiaying Wang, Matteo Fasulo, Danaé Broustail, Marija Zelic, Yawei Li & Luca Benini. ## About @@ -50,6 +56,7 @@ Looking for ready-to-use weights of models? 
We host them on Hugging Face: - **LUNA** ([paper](https://arxiv.org/abs/2510.22257)) [![HF Model Card](https://img.shields.io/badge/Model%20Card-LUNA-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/PulpBio/LUNA) - **TinyMyo** ([paper](https://arxiv.org/abs/2512.15729)) [![HF Model Card](https://img.shields.io/badge/Model%20Card-TinyMyo-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/PulpBio/TinyMyo) - **LuMamba** ([paper](https://arxiv.org/abs/2603.19100)) [![HF Model Card](https://img.shields.io/badge/Model%20Card-LuMamba-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/PulpBio/LuMamba) +- **PanLUNA** ([paper](https://arxiv.org/pdf/2604.04297)) [![HF Model Card](https://img.shields.io/badge/Model%20Card-PanLUNA-ffcc4d?logo=huggingface&logoColor=black)](https://huggingface.co/PulpBio/PanLUNA) #### Why FEMBA? @@ -227,6 +234,46 @@ Tips: - Ensure `data_module:train/test/val` are initialized with the compatible dataset class. - Configuration file includes sufficient `#CHANGEME` tags and further instructions for a working example. +## Why PanLUNA? + +* Extending LUNA's channel-unification mechanism from topology invariance to **cross-modal fusion**, jointly processing EEG, ECG, and PPG within a single shared encoder via sensor-type embeddings - no modality-specific backbones, no paired multimodal data required during pretraining. +* Pretrained on ~40,000 hours of heterogeneous biosignal data (TUEG, Siena, MIMIC-IV, CODE-15%, PulseDB) with a masked signal reconstruction objective. +* Strong performance on unimodal and multimodal experiments. + +➡️ Model hub: __https://huggingface.co/PulpBio/PanLUNA__ 📄 Model card: __[PanLUNA on Hugging Face](https://huggingface.co/PulpBio/PanLUNA)__ — variants, configs, and fine-tuning walkthrough. 
📜 Weights license: CC BY-ND 4.0 (use + redistribute unmodified weights with attribution; no redistribution of modified weights) 🧑‍🍳 PR-gated improvements: If you fine-tune internally and want your variant to become an official PanLUNA release, open a PR with configs, logs, and evals. We'll review; if it looks good, we'll retrain/validate and publish an official PanLUNA checkpoint. + +### What you'll find on the hub + +* Pre-trained checkpoint. +* Instructions to get started on fine-tuning experiments. + +### Quick download with `huggingface_hub`: + +``` +pip install huggingface_hub +``` + +```python +from huggingface_hub import snapshot_download + +# downloads all pre-trained variants and safetensors into ./checkpoints/PanLUNA +snapshot_download(repo_id="PulpBio/PanLUNA", repo_type="model", local_dir="checkpoints/PanLUNA") +``` + +Include the safetensors checkpoint path as input and run fine-tuning on the command line: + +```bash +python -u run_train.py +experiment=PanLUNA_finetune \ + pretrained_safetensors_path=/absolute/path/to/checkpoints/PanLUNA/PanLUNA.safetensors +``` + +### Tips: + +* Data preprocessing scripts are provided in `/make_datasets` for various downstream datasets. Download the corresponding dataset, locate its preprocessing script by name matching, and adjust key parameters. +* Adapt configuration file `config/experiment/PanLUNA_finetune.yaml` to your specific task with correct data module (for unimodal experiments `config/data_module/finetune_data_module_unimodal_PanLUNA` or multimodal experiments `config/data_module/finetune_data_module_multimodal_PanLUNA`), classification type (binary `bc`, multi-class `mcc` and multi-label `mlp`) and change `model.num_classes` accordingly. For different fine-tuning strategies adjust `finetuning.mode` parameter with `lora` for Low-Rank Adaptation, `freeze_encoder` for frozen backbone or select `full` for complete adaptation. 
+ * Ensure `data_module:train/test/val` are initialized with the compatible dataset class. Leverage `config/data_module/finetune_data_module_multimodal_PanLUNA:data_module.train.channel_groups` to specify channels available in the dataset and `channel_start`/`channel_end` to tweak the used subset. + * Configuration file includes sufficient `#CHANGEME` tags and further instructions for a working example. + ## Features - **Modular Design**: The repository is organized into modules for data loading, models, training tasks, and more, making it easy to extend and adapt for new research projects. diff --git a/dataset_types.yaml b/dataset_types.yaml new file mode 100644 index 0000000..34c5e58 --- /dev/null +++ b/dataset_types.yaml @@ -0,0 +1,33 @@ +dataset_types: # both pretraining datasets and unimodal finetuning datasets + tueg: + channels: [FP1-F7, F7-T3, T3-T5, T5-O1, FP2-F8, F8-T4, T4-T6, T6-O2, T3-C3, C3-CZ, CZ-C4, C4-T4, FP1-F3, F3-C3, C3-P3, P3-O1, FP2-F4, F4-C4, C4-P4, P4-O2, A1-T3, T4-A2] + location_fn: "eeg" + sensor_type: 1 + siena: + channels: [FP1, FP2, F3, C3, P3, O1, F7, T3, T5, FC1, FC5, CP1, CP5, F9, FZ, CZ, PZ, F4, C4, P4, O2, F8, T4, T6, FC2,FC6, CP2, CP6, F10] + location_fn: "eeg" + sensor_type: 1 + code15: + channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6] + location_fn: "ecg" + sensor_type: 0 + mimicIV: + channels: [I, II, III, AVR, AVF, AVL, V1, V2, V3, V4, V5, V6] + location_fn: "ecg" + sensor_type: 0 + pulsedb: + channels: [II, PPG] + location_fn: "ecg" + sensor_type: [0, 2] + ptbxl: + channels: [I, II, III, AVR, AVF, AVL, V1, V2, V3, V4, V5, V6] + location_fn: "ecg" + sensor_type: 0 + csn: + channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6] + location_fn: "ecg" + sensor_type: 0 + cpsc2018: + channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6] + location_fn: "ecg" + sensor_type: 0 \ No newline at end of file diff --git a/finetune_data_module_multimodal_PanLUNA.yaml b/finetune_data_module_multimodal_PanLUNA.yaml 
new file mode 100644 index 0000000..1f52129 --- /dev/null +++ b/finetune_data_module_multimodal_PanLUNA.yaml @@ -0,0 +1,109 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +# This is example for SEED-VII dataset +# To adjust for other multimodal datasets replace channel_groups field with appropriate modalities and their channels +# Adjust slicing with channel start and end indexes +data_module: + _target_: data_module.finetune_data_module.FinetuneDataModule + name: "eeg" + cfg: + num_workers: ${num_workers} + batch_size: ${batch_size} + train: + _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset' + hdf5_file: "#CHANGEME" + <<: &seed_channels + channel_groups: + eeg: + - FP1 + - FPZ + - FP2 + - AF3 + - AF4 + - F7 + - F5 + - F3 + - F1 + - FZ + - F2 + - F4 + - F6 + - F8 + - FT7 + - FC5 + - FC3 + - FC1 + - FCZ + - FC2 + - FC4 + - FC6 + - FT8 + - T7 + - C5 + - C3 + - C1 + - CZ + - C2 + - C4 + - C6 + - T8 + - TP7 + - CP5 + - CP3 + - CP1 + - CPZ + - CP2 + - CP4 + - CP6 + - TP8 + - P7 + - P5 + - P3 + - P1 + - PZ + - P2 + - P4 + - P6 + - P8 + - PO7 + - PO5 + - PO3 + 
- POZ + - PO4 + - PO6 + - PO8 + - CB1 + - O1 + - OZ + - O2 + - CB2 + ecg: + - II + channel_start: 0 # change for fraction of channels + channel_end: 63 + val: + _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset' + hdf5_file: "#CHANGEME" + <<: *seed_channels + test: + _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset' + hdf5_file: "#CHANGEME" + <<: *seed_channels \ No newline at end of file diff --git a/finetune_data_module_unimodal_PanLUNA.yaml b/finetune_data_module_unimodal_PanLUNA.yaml new file mode 100644 index 0000000..f6ccf55 --- /dev/null +++ b/finetune_data_module_unimodal_PanLUNA.yaml @@ -0,0 +1,49 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. 
* +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +defaults: + - dataset_types + +data_module: + _target_: data_module.finetune_data_module.FinetuneDataModule + cfg: + num_workers: ${num_workers} + batch_size: ${batch_size} + train: + _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 # example + channels: ${dataset_types.csn.channels} # example + location_fn: ${dataset_types.csn.location_fn} + sensor_type: ${dataset_types.csn.sensor_type} + val: + _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.csn.channels} + location_fn: ${dataset_types.csn.location_fn} + sensor_type: ${dataset_types.csn.sensor_type} + test: + _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.csn.channels} + location_fn: ${dataset_types.csn.location_fn} + sensor_type: ${dataset_types.csn.sensor_type} \ No newline at end of file diff --git a/finetune_task_PanLUNA.py b/finetune_task_PanLUNA.py new file mode 100644 index 0000000..050b945 --- /dev/null +++ b/finetune_task_PanLUNA.py @@ -0,0 +1,419 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. 
* +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* + +import torch +import torch.nn as nn +import pytorch_lightning as pl +import hydra +import torch_optimizer as torch_optim +import torch.nn.functional as F +from torchmetrics import MetricCollection +from torchmetrics.classification import ( + Accuracy, Precision, Recall, AUROC, + AveragePrecision, CohenKappa, F1Score +) +from peft import LoraConfig, get_peft_model, TaskType + +class ChannelWiseNormalize: + def __init__(self, eps=1e-8): + self.eps = eps + + def __call__(self, tensor): + with torch.no_grad(): + # tensor: (B, C, T) + mean = tensor.mean(dim=2, keepdim=True) + std = tensor.std(dim=2, keepdim=True) + return (tensor - mean) / (std + self.eps) + +class FinetuneTask(pl.LightningModule): + """ + PyTorch Lightning module for fine-tuning a classification model, with support for: + + - Classification types: + - `bc`: Binary Classification + - `ml`: Multi-Label Classification + - `mc`: Multi-Label Classification for TUAR + - `mcc`: Multi-Class Classification + - `mmc`: Multi-Class Multi-Output Classification + - `mlp`: Multi-Label Classification for ECG tasks, e.g. PTB-XL Super: each sample (input shape (12, 2560)) has a multi-hot target such as [0, 1, 0, 0, 1] over 5 possible labels, with multiple labels allowed per sample + + - Different fine-tuning strategies available: LoRA, frozen encoder and full. 
+ - Metric logging during training, validation, and testing, including accuracy, precision, recall, F1 score, AUROC, and more + - Optional input normalization with configurable normalization functions + - Custom optimizer support including SGD, Adam, AdamW, and LAMB + - Learning rate schedulers with configurable scheduling strategies + - Layer-wise learning rate decay for fine-grained learning rate control across model blocks + """ + def __init__(self, hparams): + """ + Initialize the FinetuneTask module. + + Args: + hparams (DictConfig): Hyperparameters and configuration loaded via Hydra. + """ + super().__init__() + self.save_hyperparameters(hparams) + self.model = hydra.utils.instantiate(self.hparams.model) + + print("--- MODEL MODULE NAMES ---") + for name, module in self.model.named_modules(): + # Print only the linear or attention-related modules to keep the list short + if any(keyword in name for keyword in ['qkv', 'proj', 'attn', 'cross_attention']): + print(f"'{name}'") + + self.num_classes = self.hparams.model.num_classes + self.classification_type = self.hparams.classification_type + + # Input normalization + if self.hparams.input_normalization is not None and self.hparams.input_normalization.normalize: + self.normalize = True + self.normalize_fct = ChannelWiseNormalize() + + # Loss function + if self.classification_type == "mc": + self.criterion = nn.BCEWithLogitsLoss() + elif self.classification_type == "mlp": + self.criterion = nn.BCEWithLogitsLoss() + else: + self.criterion = nn.CrossEntropyLoss() + + # Classification mode detection + if not isinstance(self.num_classes, int): + raise TypeError("Number of classes must be an integer.") + elif self.num_classes < 2: + raise ValueError("Number of classes must be at least 1.") + elif self.num_classes == 2: + self.classification_task = "binary" + else: + self.classification_task = "multiclass" + + # Metrics + if self.classification_type != "mlp": + label_metrics = MetricCollection([ + 
Accuracy(task=self.classification_task, num_classes=self.num_classes, average="macro"), + Recall(task=self.classification_task, num_classes=self.num_classes, average="macro"), + Precision(task=self.classification_task, num_classes=self.num_classes, average="macro"), + F1Score(task=self.classification_task, num_classes=self.num_classes, average="macro"), + CohenKappa(task=self.classification_task, num_classes=self.num_classes) + ]) + logit_metrics = MetricCollection([ + AUROC(task=self.classification_task, num_classes=self.num_classes, average="macro"), + AveragePrecision(task=self.classification_task, num_classes=self.num_classes, average="macro"), + ]) + elif self.classification_type == 'mlp': + label_metrics = MetricCollection([ + Accuracy(task='multilabel', num_labels=self.num_classes, average="macro"), + Recall(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0), + Precision(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0), + F1Score(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0), + ]) + + logit_metrics = MetricCollection([ + AUROC(task='multilabel', num_labels=self.num_classes, average="macro"), + AveragePrecision(task='multilabel', num_labels=self.num_classes, average="macro"), + ]) + + self.train_label_metrics = label_metrics.clone(prefix='train_') + self.val_label_metrics = label_metrics.clone(prefix='val_') + self.test_label_metrics = label_metrics.clone(prefix='test_') + self.train_logit_metrics = logit_metrics.clone(prefix='train_') + self.val_logit_metrics = logit_metrics.clone(prefix='val_') + self.test_logit_metrics = logit_metrics.clone(prefix='test_') + + def load_pretrained_checkpoint(self, model_ckpt): + """ + Load a pretrained model checkpoint and unfreeze specific layers for fine-tuning. 
+ """ + if model_ckpt is not None: + assert self.model.classifier is not None + print("Loading pretrained checkpoint") + ckpt = torch.load(model_ckpt) + state_dict = ckpt['state_dict'] + + # Remove decoder head and channel embedding weights since they are not needed for fine-tuning + state_dict = {k: v for k, v in state_dict.items() if 'decoder_head' not in k and "channel_emb" not in k} + + new_state_dict = {} + for k, v in state_dict.items(): + if k.startswith("model."): + new_state_dict[k[len("model."):]] = v + else: + new_state_dict[k] = v + + missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False) + print("missing:", missing) + print("unexpected:", unexpected) + + if self.hparams.finetuning.mode == "lora": + print("=> Applying LoRA strategy...") + lora_cfg = self.hparams.finetuning.lora + + config = LoraConfig( + task_type=TaskType.FEATURE_EXTRACTION, + r=lora_cfg.r, + lora_alpha=lora_cfg.alpha, + lora_dropout=lora_cfg.dropout, + target_modules=lora_cfg.target_modules + ) + + self.model = get_peft_model(self.model, config) + self.model.print_trainable_parameters() + + for p in self.model.base_model.model.classifier.parameters(): + p.requires_grad=True + + elif self.hparams.finetuning.mode == "freeze_encoder": + print("=> Applying Freeze Encoder strategy...") + + for name, param in self.model.named_parameters(): + param.requires_grad = False + if 'classifier' in name: + param.requires_grad = True + + else: + print("=> Applying Full Fine-tuning strategy...") + + print("Pretrained model ready.") + + def generate_fake_mask(self, batch_size, C, T): + """ + Create a dummy mask tensor to simulate attention masking. + + Args: + batch_size (int): Number of samples. + C (int): Number of channels. + T (int): Temporal dimension. + + Returns: + torch.Tensor: Boolean mask tensor of shape (B, C, T). 
+ """ + return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device) + + def _step(self, X, mask, channel_locations, channel_names, sensor_type): + """ + Perform forward pass and post-process predictions. + + Args: + X (torch.Tensor): Input tensor. + mask (torch.Tensor): Attention mask tensor. + + Returns: + dict: Dictionary containing predicted labels, probabilities, and logits. + """ + y_pred_logits, _ = self.model(x_signal=X, mask=mask, channel_locations=channel_locations, sensor_type=sensor_type, channel_names=channel_names) + + if self.classification_type in ("bc", "mcc", "ml"): + y_pred_probs = torch.softmax(y_pred_logits, dim=1) + y_pred_label = torch.argmax(y_pred_probs, dim=1) + elif self.classification_type in ("mc", "mlp"): + y_pred_probs = torch.sigmoid(y_pred_logits) + y_pred_label = torch.round(y_pred_probs) + elif self.classification_type == "mmc": + y_pred_logits = y_pred_logits.view(-1, 6) + y_pred_probs = torch.sigmoid(y_pred_logits) + y_pred_label = torch.argmax(y_pred_probs, dim=-1) + + return { + 'label': y_pred_label, + 'probs': y_pred_probs, + 'logits': y_pred_logits, + } + + def training_step(self, batch, batch_idx): + X, y = batch["input"], batch["label"] + + channel_locations = batch["channel_locations"] + channel_names = batch.get("channel_names", None) + sensor_type = batch["sensor_type"] + + if self.normalize: + X = self.normalize_fct(X) + + mask = None + y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type) + if self.classification_type == "mmc": + loss = self.criterion(y_pred['logits'], y) + elif self.classification_type in ("mc", "mlp"): + loss = self.criterion(y_pred['logits'], y.float()) + else: + loss = self.criterion(y_pred['logits'], y) + + self.train_label_metrics(y_pred['label'], y) + self.train_logit_metrics(self._handle_binary(y_pred['logits']), y) + self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False) + self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False) + 
self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True) + + lr = self.optimizers().param_groups[0]['lr'] + self.log('lr', lr, on_step=True, on_epoch=False, prog_bar=False, logger=True) + return loss + + def validation_step(self, batch, batch_idx): + X, y = batch["input"], batch["label"] + + channel_locations = batch["channel_locations"] + channel_names = batch.get("channel_names", None) + sensor_type = batch["sensor_type"] + + if self.normalize: + X = self.normalize_fct(X) + + mask = None + y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type) + if self.classification_type == "mmc": + y = y.view(-1) + loss = self.criterion(y_pred['logits'], y) + elif self.classification_type in ("mc", "mlp"): + loss = self.criterion(y_pred['logits'], y.float()) + else: + loss = self.criterion(y_pred['logits'], y) + + self.val_label_metrics(y_pred['label'], y) + self.val_logit_metrics(self._handle_binary(y_pred['logits']), y) + self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True) + self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True) + self.log('val_loss', loss, prog_bar=True, logger=True, sync_dist=True) + return loss + + def test_step(self, batch, batch_idx): + X, y = batch["input"], batch["label"] + + channel_locations = batch["channel_locations"] + channel_names = batch.get("channel_names", None) + sensor_type = batch["sensor_type"] + + if self.normalize: + X = self.normalize_fct(X) + + mask = None + y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type) + if self.classification_type == "mmc": + y = y.view(-1) + loss = self.criterion(y_pred['logits'], y) + elif self.classification_type in ("mc", "mlp"): + loss = self.criterion(y_pred['logits'], y.float()) + else: + loss = self.criterion(y_pred['logits'], y) + + self.test_label_metrics(y_pred['label'], y) + self.test_logit_metrics(self._handle_binary(y_pred['logits']), y) + self.log_dict(self.test_label_metrics, 
on_step=False, on_epoch=True) + self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True) + self.log('test_loss', loss, prog_bar=True, logger=True, sync_dist=True) + return loss + + def on_train_epoch_end(self): + self.log_dict(self.train_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + self.log_dict(self.train_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + + def on_validation_epoch_end(self): + self.log_dict(self.val_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + self.log_dict(self.val_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + + def on_test_epoch_end(self): + self.log_dict(self.test_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + self.log_dict(self.test_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True) + + + def lr_scheduler_step(self, scheduler, metric): + """ + Custom scheduler step function for step-based LR schedulers + """ + if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): + scheduler.step(metric) + else: + scheduler.step_update(num_updates=self.global_step) + + def configure_optimizers(self): + """ + Configure the optimizer and learning rate scheduler. + + Returns: + dict: Configuration dictionary with optimizer and LR scheduler. + """ + num_blocks = self.hparams.model.depth + params_to_pass = [] + base_lr = self.hparams.optimizer.lr + decay_factor = self.hparams.layerwise_lr_decay + + for name, param in self.model.named_parameters(): + lr = base_lr + if self.hparams.finetuning.mode == "full" or self.hparams.finetuning.mode == "freeze_encoder": + if "block." 
in name or 'norm_layers' in name: + block_nr = int(name.split('.')[1]) + lr *= decay_factor ** (num_blocks - block_nr) + params_to_pass.append({"params": param, "lr": lr}) + else: + if 'norm_layers' in name: + block_nr = int(name.split('.')[1]) + lr *= decay_factor ** (num_blocks - block_nr) + params_to_pass.append({"params": param, "lr": lr}) + + if self.hparams.optimizer.optim == "SGD": + optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=self.hparams.optimizer.momentum) + elif self.hparams.optimizer.optim == 'Adam': + optimizer = torch.optim.Adam(params_to_pass, lr=base_lr, weight_decay=self.hparams.optimizer.weight_decay) + elif self.hparams.optimizer.optim == 'AdamW': + optimizer = torch.optim.AdamW(params_to_pass, lr=base_lr, weight_decay=self.hparams.optimizer.weight_decay, betas=self.hparams.optimizer.betas) + elif self.hparams.optimizer.optim == 'LAMB': + optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr) + else: + raise NotImplementedError("No valid optimizer name") + + if self.hparams.scheduler_type == "multi_step_lr": + scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer) + lr_scheduler_config = { + "scheduler": scheduler, + "interval": "step", + "frequency": 1 + } + elif self.hparams.scheduler_type == 'cosine': + scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer, + total_training_opt_steps=self.trainer.estimated_stepping_batches) + lr_scheduler_config = { + "scheduler": scheduler, + "interval": "step", + "frequency": 1 + } + else: + scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer, total_training_opt_steps= self.trainer.estimated_stepping_batches) + lr_scheduler_config = { + "scheduler": scheduler, + "interval": "step", + "frequency": 1 + } + + + return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config} + + def _handle_binary(self, preds): + """ + Special handling for binary classification probabilities. 
+ + Args: + preds (torch.Tensor): Logit outputs. + + Returns: + torch.Tensor: Probabilities for the positive class. + """ + if self.classification_task == 'binary' and self.classification_type != 'mc': + return preds[:, 1].squeeze() + else: + return preds diff --git a/finetune_task_PanLUNA.yaml b/finetune_task_PanLUNA.yaml new file mode 100644 index 0000000..5869f2e --- /dev/null +++ b/finetune_task_PanLUNA.yaml @@ -0,0 +1,22 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +task: + _target_: 'tasks.finetune_task_PanLUNA.FinetuneTask' \ No newline at end of file diff --git a/finetuning_multimodal_datasets_PanLUNA.py b/finetuning_multimodal_datasets_PanLUNA.py new file mode 100644 index 0000000..5a1ec2a --- /dev/null +++ b/finetuning_multimodal_datasets_PanLUNA.py @@ -0,0 +1,126 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. 
* +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +import torch +import h5py +import numpy as np +from typing import Optional + +from models.modules.lead_positions import ( + map_lead_labels_to_angles, + get_channel_indices, + get_channel_locations +) + +MODALITY_TO_SENSOR_TYPE = {"ecg": 0, "eeg": 1, "ppg": 2} +KNOWN_MODALITIES = {"eeg", "ecg", "ppg"} + +def _compute_channel_location( + channel_names: list[str], + channel_modalities: list[str] +): + """ + Build a (n_channels, max_dim) location array. + Handles mixed EEG/ECG by zero-padding to the largest feature dimension. + """ + raw_locs = [] + for ch, mod in zip(channel_names, channel_modalities): + if mod == "eeg": + loc = np.stack(get_channel_locations([ch]), axis=0) + elif mod == "ecg" or mod == "ppg": + loc = map_lead_labels_to_angles([ch]) + raw_locs.append(loc) + + max_dim = max(loc.shape[1] for loc in raw_locs) + padded = [ + np.pad(loc, ((0, 0), (0, max_dim - loc.shape[1])), mode="constant") + if loc.shape[1] < max_dim else loc + for loc in raw_locs + ] + + return np.vstack(padded) + +class FinetuningMultimodal_Dataset(torch.utils.data.Dataset): + """ + Unified class for multimodal finetuning datasets. Handles channel selection and channel location padding with support of hydra configuration. + + Args: + hdf5_file: Path to the .h5 file. 
+ channel_groups: All channels in the dataset organized as in the finetune_data_module_multimodal.yaml + channel_start: Starting index for channel slicing. Allows taking all or fraction of modalities. + channel_end: Ending index for channel slicing. + """ + def __init__( + self, + hdf5_file: str, + channel_groups: dict[str, list[str]], + channel_start: Optional[int] = None, + channel_end: Optional[int] = None, + ): + super().__init__() + self._x_slice = slice(channel_start, channel_end) + + # Flatten channel_groups into parallel list of names and modalities + self.channel_names = [] + flat_modalities = [] + + for modality, names in channel_groups.items(): + self.channel_names.extend(names) + if modality in KNOWN_MODALITIES: + flat_modalities.extend([modality] * len(names)) + + self.channel_indices = torch.tensor(get_channel_indices(self.channel_names), dtype=torch.long) + + # Obtains channel locations + locs = _compute_channel_location(self.channel_names, flat_modalities) + self.channel_locations = torch.from_numpy(locs).float() + self.sensor_type = torch.tensor([MODALITY_TO_SENSOR_TYPE[m] for m in flat_modalities], dtype=torch.long) + + # Open HDF5 and build flat + self.data = h5py.File(hdf5_file, "r") + self.keys = list(self.data.keys()) + self.index_map = [] + + for key in self.keys: + group_size = len(self.data[key]['X']) + self.index_map.extend([(key, i) for i in range(group_size)]) + + def __len__(self): + return len(self.index_map) + + def __getitem__(self, index): + + group_key, sample_idx = self.index_map[index] + grp = self.data[group_key] + + X = torch.FloatTensor(grp["X"][sample_idx])[self._x_slice, :] + label = torch.tensor(grp["y"][sample_idx], dtype=torch.long) + + return_dict = { + "input": X, + "channel_names": self.channel_indices[self._x_slice], + "channel_locations": self.channel_locations[self._x_slice], + "sensor_type": self.sensor_type[self._x_slice], + "label": label + } + + return return_dict + + + \ No newline at end of file diff --git 
a/finetuning_unimodal_datasets_PanLUNA.py b/finetuning_unimodal_datasets_PanLUNA.py new file mode 100644 index 0000000..19dd1a0 --- /dev/null +++ b/finetuning_unimodal_datasets_PanLUNA.py @@ -0,0 +1,97 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +import torch +import h5py +import numpy as np +from models.modules.lead_positions import ( + get_channel_indices, + get_channel_locations, + map_lead_labels_to_angles, +) + +class FinetuningUnimodal_Dataset(torch.utils.data.Dataset): + """ + Unified class for unimodal finetuning datasets. + + Args: + hdf5_file: Path to the .h5 file. + channels: List of channels. + location_fn: "eeg" or "ecg". + sensor_type: 0 (ECG), 1(EEG), 2(PPG). + num_channels: Number of channels taken from total channel list. 
+ """ + def __init__( + self, + hdf5_file: str, + channels: list[str], + location_fn: str, + sensor_type: int | list[int], + num_channels: int | None = None, + ): + super().__init__() + channel_names = channels[:num_channels] if num_channels else channels + self.num_channels = len(channel_names) + + if location_fn == "eeg": + locs = np.stack(get_channel_locations(channel_names), axis=0) + self.channel_locations = torch.from_numpy(locs).float() + else: + self.channel_locations = torch.FloatTensor(map_lead_labels_to_angles(channel_names)) + + self.channel_indices = torch.tensor(get_channel_indices(channel_names), dtype=torch.long) + + if isinstance(sensor_type, list): + self.sensor_type = torch.tensor(sensor_type, dtype=torch.long) + else: + self.sensor_type = torch.full((self.num_channels, ), sensor_type, dtype=torch.long) + + self.data = h5py.File(hdf5_file, "r") + self.keys = list(self.data.keys()) + self.index_map = [] + + for key in self.keys: + if key == 'index_map': + continue + group_size = len(self.data[key]["X"]) + self.index_map.extend([(key, i) for i in range(group_size)]) + + def __len__(self): + return len(self.index_map) + + def __getitem__(self, index): + + group_key, sample_idx = self.index_map[index] + grp = self.data[group_key] + X = torch.FloatTensor(grp["X"][sample_idx]) + label = torch.tensor(grp["y"][sample_idx], dtype=torch.long) + + item = { + "input": X, + "channel_names": self.channel_indices, + "channel_locations": self.channel_locations, + "sensor_type": self.sensor_type, + "label": label + } + + return item + + def __del__(self): + if hasattr(self, "data"): + self.data.close() \ No newline at end of file diff --git a/lead_positions.py b/lead_positions.py new file mode 100644 index 0000000..cfe221b --- /dev/null +++ b/lead_positions.py @@ -0,0 +1,212 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* 
Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* + + + +import numpy as np +import torch +import torch.nn as nn +import mne + +SEED_VII_CH_LIST = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2', 'ECG'] + +MIMIC_IV_ECG_CHANNEL_LIST = [ + 'I', + 'II', + 'III', + 'AVR', + 'AVF', + 'AVL', + 'V1', + 'V2', + 'V3', + 'V4', + 'V5', + 'V6' +] + +SIENA_CHANNEL_LIST = [ + "FP1", + "FP2", + "F3", + "C3", + "P3", + "O1", + "F7", + "T3", + "T5", + "FC1", + "FC5", + "CP1", + "CP5", + "F9", + "FZ", + "CZ", + "PZ", + "F4", + "C4", + "P4", + "O2", + "F8", + "T4", + "T6", + "FC2", + "FC6", + "CP2", + "CP6", + "F10", +] + +TUEG_CHANNEL_LIST = [ + "FP1-F7", + "F7-T3", + "T3-T5", + "T5-O1", + "FP2-F8", + "F8-T4", + "T4-T6", + "T6-O2", + "T3-C3", + "C3-CZ", + "CZ-C4", + "C4-T4", + "FP1-F3", + "F3-C3", + "C3-P3", + "P3-O1", + "FP2-F4", + "F4-C4", + "C4-P4", + "P4-O2", + "A1-T3", + "T4-A2", +] + +VITALDB_CHANNEL_LIST = [ + 'II', + 'PPG' +] + +HMC_CHANNEL_LIST = [ + "F4", + 
"C4", + "O2", + "C3", +] + +all_channels = set() + +for ds in [ + SEED_VII_CH_LIST, + MIMIC_IV_ECG_CHANNEL_LIST, + TUEG_CHANNEL_LIST, + SIENA_CHANNEL_LIST, + VITALDB_CHANNEL_LIST, + HMC_CHANNEL_LIST +]: + for ch in ds: + all_channels.add(ch) +CHANNEL_NAMES_TO_IDX = {ch: i for i, ch in enumerate(sorted(all_channels))} +CHANNEL_IDX_TO_NAMES = {i: ch for ch, i in CHANNEL_NAMES_TO_IDX.items()} + +def get_channel_indices(channel_names): + indices = [] + for name in channel_names: + indices.append(CHANNEL_NAMES_TO_IDX[name]) + return indices + +def get_channel_names(channel_indices): + names = [] + for idx in channel_indices: + names.append(CHANNEL_IDX_TO_NAMES[idx]) + return names + +def get_channel_locations(channel_names): + if "-" in channel_names[0]: + names = list(set([part for ch in channel_names for part in ch.split('-')])) + else: + names = channel_names + ch_types = ['eeg'] * len(names) # Channel types + info = mne.create_info(ch_names=names, sfreq=256, ch_types=ch_types) + info = info.set_montage(mne.channels.make_standard_montage("standard_1005"),match_case=False,match_alias={'cb1': 'POO7', 'cb2': 'POO8'}) + locs = [] + for name in channel_names: + if name in TUEG_CHANNEL_LIST: + electrode1, electrode2 = name.split('-') + loc1 = info.get_montage().get_positions()['ch_pos'][electrode1] + loc2 = info.get_montage().get_positions()['ch_pos'][electrode2] + locs.append(((loc1 + loc2) / 2)) + else: + locs.append(info.get_montage().get_positions()['ch_pos'][name]) + return locs + +# We encode each ECG lead with 2-angle-tuple vector from https://www.ijcai.org/proceedings/2021/0495.pdf +lead_labels_mapping = { + 'I': (np.pi/2, np.pi/2), + 'II': (5*np.pi/6, np.pi/2), + 'III': (5*np.pi/6, -np.pi/2), + 'AVR': (np.pi/3, -np.pi/2), + 'AVL': (np.pi/3, np.pi/2), + 'AVF': (np.pi, np.pi/2), + 'V1': (np.pi/2, -np.pi/18), + 'V2': (np.pi/2, np.pi/18), + 'V3': (np.pi/2, np.pi/18), + 'V4': (11*np.pi/20, np.pi/6), + 'V5': (8*np.pi/15, np.pi/3), + 'V6': (8*np.pi/15, np.pi/2), + 'PPG': 
def map_lead_labels_to_angles(lead_labels):
    """
    Map ECG/PPG lead labels to their 2-angle encodings.

    Args:
        lead_labels (list): Lead labels, as ``str`` or ``bytes`` (h5py string
            attributes read back as bytes).

    Returns:
        np.ndarray: Array of shape (num_leads, 2) with the angle tuple of each
        lead, in input order.
    """
    lead_labels_decoded = [lead.decode('utf-8') if isinstance(lead, bytes) else lead for lead in lead_labels]
    return np.array([lead_labels_mapping[lead] for lead in lead_labels_decoded], dtype=float)

class ChannelEmbeddings(nn.Module):
    """Learnable embedding for every channel id in CHANNEL_NAMES_TO_IDX."""

    def __init__(self, embed_dim):
        super(ChannelEmbeddings, self).__init__()
        self.embeddings = nn.Embedding(len(CHANNEL_NAMES_TO_IDX), embed_dim)

    def forward(self, indices):
        # indices: integer tensor of channel ids -> (..., embed_dim)
        return self.embeddings(indices)

    def initialize_weights(self):
        # Bug fix: `torch.init` does not exist (AttributeError at call time);
        # the init helpers live in torch.nn.init.
        nn.init.normal_(self.embeddings.weight, std=0.02)

class SensorEmbeddings(nn.Module):
    """Learnable embedding for the three sensor types (EEG / ECG / PPG)."""

    def __init__(self, embed_dim):
        super(SensorEmbeddings, self).__init__()
        self.embeddings = nn.Embedding(3, embed_dim)

    def forward(self, indices):
        # indices: integer tensor of sensor-type ids in {0, 1, 2}
        return self.embeddings(indices)

    def initialize_weights(self):
        # Same fix as ChannelEmbeddings: use torch.nn.init, not torch.init.
        nn.init.normal_(self.embeddings.weight, std=0.02)
sampling_freq = 400
CODE15_channels = ['I', "II", "III", "AVR", "AVL", "AVF", "V1", "V2", "V3", "V4", "V5", "V6"]

def dump_data(data_slice, output_dir, data_group_name, file_idx):
    """Append one processed batch as a new group of CODE15_<file_idx>.h5."""
    output_path = os.path.join(output_dir, f"CODE15_{file_idx}.h5")
    with h5py.File(output_path, "a") as h5f:
        # Record the lead order once per file.
        if h5f.attrs.get("channel_names") is None:
            h5f.attrs['channel_names'] = CODE15_channels
        group = h5f.create_group(data_group_name)
        group.create_dataset("X", data=np.array(data_slice, dtype=np.float16), dtype='float16')

def process_single_file(file_path, output_dir, downsample_fs, split_signal, file_idx, worker_id):
    """
    Process one raw Code15 HDF5 file and append results to CODE15_<file_idx>.h5.

    Each raw sample has shape (4096, 12) at 400 Hz; 7 s recordings (2800
    samples) are zero-padded to 4096. The first and last 48 samples are
    stripped so both 7 s and 10 s recordings reduce to a consistent
    4000-sample core before filtering.

    Args:
        file_path (str): Path to the raw Code15 HDF5 file.
        output_dir (str): Directory receiving the processed .h5 files.
        downsample_fs (int): Target sampling frequency.
        split_signal (int or None): Segment length in seconds (no overlap,
            no padding); None keeps the full-length signal.
        file_idx (int): Index used to name the output file.
        worker_id (int): Position of this worker's tqdm progress bar.
    """
    batch = 1024
    with h5py.File(file_path, "r") as f:
        # The full tracings array is too big for memory, so stream it in
        # fixed-size batches.
        ecgs = f['tracings']
        total = ecgs.shape[0]
        for start in tqdm.tqdm(range(0, total, batch), desc=f"Worker {worker_id}", position=worker_id):
            stop = min(start + batch, total)
            # (batch, 4096, 12) -> (batch, 12, 4096), keep the middle 4000 samples.
            chunk = np.transpose(np.array(ecgs[start:stop]), axes=(0, 2, 1))[:, :, 48:4048]
            processed = preprocess_signal(chunk, sampling_freq, 0.5, 120, downsample_fs, None)
            if split_signal is not None:
                segments = time_segmenting(np.expand_dims(processed, axis=0), split_signal, sampling_freq, downsample_fs, None)
                merged = np.concatenate(segments, axis=0)
                processed = merged.reshape(-1, merged.shape[2], merged.shape[3])
            dump_data(processed, output_dir, f"code15_{start}", file_idx)
+ """ + num_workers = 18 + worker_args = [] + + # Prepare arguments for parallel processing + for idx, file in enumerate(files): + worker_args.append((file, output_dir, downsampling_fs, split_signal, idx)) + + print(f"Processing {len(worker_args)} groups of files:") + + # Create the directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Use multiprocessing to parallelize the process + with multiprocessing.Manager() as manager: + with multiprocessing.Pool(num_workers) as pool: + pool.starmap(process_single_file, [(args[0], args[1], args[2], args[3], args[4], worker_id) for worker_id, args in enumerate(worker_args)]) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Process Code15 HDF5 files and save as .h5 files.") + parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory contraining MIMIC-IV-ECG files.') + parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed .h5 files.') + parser.add_argument('--downsampling_fs', type=int, default=256, help='Desired downsampling frequency.') + parser.add_argument('--split_signal', type=int, default=5, help='Length of segments to split the signals into (in seconds). 
If None, no splitting is done.') + + args = parser.parse_args() + input_dir = args.input_dir + output_dir = args.output_dir + downsampling_fs = args.downsampling_fs + split_signal = args.split_signal + + # List all HDF5 files that are stored in the directory + files = os.listdir(input_dir) + path_files = [os.path.join(input_dir, file) for file in files] + + # Parallel writing + loading_and_processing_parallel(path_files, output_dir, downsampling_fs, split_signal) \ No newline at end of file diff --git a/make_cpsc2018_dataset.py b/make_cpsc2018_dataset.py new file mode 100644 index 0000000..4c49c99 --- /dev/null +++ b/make_cpsc2018_dataset.py @@ -0,0 +1,147 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +import wfdb +import numpy as np +import pandas as pd +import pickle +import argparse +import tqdm +import ast +import os +import h5py +from scipy.io import loadmat + +from process_raw_ecg import preprocess_signal + +def create_hdf5(source_dir, target_file, finetune=True, group_size=1000): + """ + Lists all the pickle files in the source directory and writes them in the .h5 file. 
def process_csv_files(args, csv_file, split_type):
    """
    Process CPSC2018 ECG recordings listed in a CSV split file.

    Args:
        args: Parsed command line arguments.
        csv_file (str): Path to the CSV file containing file paths.
        split_type (str): Type of split ('train', 'val', 'test').
    """
    # Make directory for this split if it doesn't exist
    os.makedirs(os.path.join(args.output_dir, split_type), exist_ok=True)

    print(f"Processing {csv_file}...")
    df = pd.read_csv(csv_file)

    for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        try:
            # First 2500 samples at the assumed 500 Hz -> a 5 s window.
            file_path = os.path.join(args.input_dir, f"{row['filename']}.mat")
            ecg = np.array(loadmat(file_path)['val'][:, :2500])

            sampling_rate = 500  # assumed CPSC2018 rate -- TODO confirm against the source data

            # Band-pass filter and (optionally) downsample.
            processed = preprocess_signal(ecg, sampling_rate, 0.5, 120, args.downsample_fs, None)

            # Labels: the 9 class columns are the last 9 columns of the CSV.
            labels = row.iloc[-9:].values.astype(np.int8)

            out_path = os.path.join(args.output_dir, split_type, f'cpsc2018-{idx}.pkl')
            with open(out_path, "wb") as fh:
                pickle.dump({"X": processed, "y": labels}, fh)

        except Exception as e:
            # Best effort: a single corrupt record must not abort the split.
            print(f"Skipped file {file_path}. Error occured: {e}")
            continue
+ """ + train_csv = os.path.join(args.csv_files_dir, "icbeb_train.csv") + val_csv = os.path.join(args.csv_files_dir, "icbeb_val.csv") + test_csv = os.path.join(args.csv_files_dir, "icbeb_test.csv") + + # Create output directory if it doesn't exist + os.makedirs(args.output_dir, exist_ok=True) + + # Write each .csv file into pickle files with X and y + process_csv_files(args, train_csv, split_type="train") + process_csv_files(args, val_csv, split_type="val") + process_csv_files(args, test_csv, split_type="test") + + # Finally write to HDF5 + to_do = [f'train', f'val', f'test'] + for td in to_do: + if os.path.exists(output_dir + '/' + td + ".h5"): + print(f"File {td}.h5 already exists!") + else: + print(f"Creating file {td}.h5.") + create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5") + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Process CPSC2018 ECG dataset and save as pickle files.") + parser.add_argument('--input_dir', type=str, default="#CHANGEME", help="Directory containing CPSC2018 WFDB files.") + parser.add_argument('--output_dir', type=str, default='#CHANGEME', help="Directory to save processed pickle files.") + parser.add_argument("--csv_files_dir", type=str, default='#CHANGEME') + parser.add_argument("--downsample_fs", type=int, default=256, help="Desired downsampling frequency. If None, no downsampling is done.") + + args = parser.parse_args() + output_dir = args.output_dir + + main_splitted(args) \ No newline at end of file diff --git a/make_csn_dataset.py b/make_csn_dataset.py new file mode 100644 index 0000000..14e18ee --- /dev/null +++ b/make_csn_dataset.py @@ -0,0 +1,157 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. 
# Standard 12-lead order of the Chapman-Shaoxing records.
leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']

def process_csv_files(args, csv_file, split_type):
    """
    Process Chapman-Shaoxing (CSN) ECG recordings listed in a CSV split file.

    Args:
        args: Parsed command line arguments.
        csv_file (str): Path to the CSV file containing file paths.
        split_type (str): Type of split ('train', 'val', 'test').
    """
    # Make directory for this split if it doesn't exist
    os.makedirs(os.path.join(args.output_dir, split_type), exist_ok=True)

    print(f"Processing {csv_file}...")
    df = pd.read_csv(csv_file)

    for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
        try:
            # Bug fix: `.split('.')[0]` truncated the path at the FIRST dot,
            # which breaks when any directory name contains a dot;
            # os.path.splitext only strips the trailing extension.
            file_path = os.path.splitext(os.path.join(args.input_dir, row['ecg_path'].lstrip('/')))[0]
            signal, fields = wfdb.rdsamp(file_path)
            ecg = np.array(signal).T  # (samples, channels) -> (channels, samples)

            sampling_rate = fields['fs']

            # Band-pass filter and (optionally) downsample.
            processed_ecg = preprocess_signal(ecg, sampling_rate, 0.5, 120, args.downsample_fs, None)

            # Labels: the 38 class columns are the last 38 columns of the CSV.
            y = row.iloc[-38:].values.astype(np.int8)

            data_dict = {"X": processed_ecg, "y": y}
            dump_path = os.path.join(args.output_dir, split_type, f'chapman-{idx}.pkl')
            with open(dump_path, "wb") as f:
                pickle.dump(data_dict, f)
        except Exception as e:
            # Best effort: a single corrupt record must not abort the split.
            print(f"Skipped file {file_path}. Error occured: {e}")
            continue

def main_splitted(args):
    """
    Process the CSN dataset into train/val/test pickle and HDF5 files.
    Uses splits provided by the MERL ICML 2024 paper for consistency in
    comparison (https://github.com/cheliu-computation/MERL-ICML2024/tree/main/finetune/data_split).
    Filtering and downsampling to args.downsample_fs happen here.

    Args:
        args: Parsed arguments from argparse.
    """
    train_csv = os.path.join(args.csv_files_dir, 'chapman_train.csv')
    val_csv = os.path.join(args.csv_files_dir, 'chapman_val.csv')
    test_csv = os.path.join(args.csv_files_dir, 'chapman_test.csv')

    # Create output directory if it doesn't exist
    os.makedirs(args.output_dir, exist_ok=True)

    # Write each csv file into pickle files with X and y
    process_csv_files(args, train_csv, split_type='train')
    process_csv_files(args, val_csv, split_type='val')
    process_csv_files(args, test_csv, split_type='test')

    # Bug fix: replaced the module-level `output_dir` global (only defined when
    # run as a script) with args.output_dir.
    for split in ('train', 'val', 'test'):
        h5_path = os.path.join(args.output_dir, f"{split}.h5")
        if os.path.exists(h5_path):
            print(f"File {split}.h5 already exists!")
        else:
            print(f"Creating file {split}.h5.")
            create_hdf5(os.path.join(args.output_dir, split), h5_path)
+ """ + # List all the files in the folder + files = sorted(os.listdir(source_dir)) + data_group = [] + + + with h5py.File(target_file, 'w') as h5f: + for i, file in enumerate(tqdm.tqdm(files)): + with open(os.path.join(source_dir, file), 'rb') as f: + + sample = pickle.load(f) + data_group.append(sample) + + if (i + 1) % group_size == 0 or i == len(files) - 1: + + grp = h5f.create_group(f"data_group_{i // group_size}") + X_data = np.array([s['X'] for s in data_group]) + grp.create_dataset("X", data=X_data) + + if(finetune): + y_data = np.array([s['y'] for s in data_group]) + grp.create_dataset("y", data=y_data) + + data_group = [] + +if __name__ == "__main__": + + parser = argparse.ArgumentParser(description="Process Chapman-Shaoxing ECG dataset and save as pickle files.") + parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory containing Chapman-Shaoxing ECG WFDB files.') + parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed pickle files.') + parser.add_argument('--csv_files_dir', type=str, default='#CHANGEME') + parser.add_argument('--downsample_fs', type=int, default=256, help='Desired downsampling frequency. If None, no downsampling is done.') + + args = parser.parse_args() + output_dir = args.output_dir + + main_splitted(args) + \ No newline at end of file diff --git a/make_hmc_dataset.py b/make_hmc_dataset.py new file mode 100644 index 0000000..70f2c17 --- /dev/null +++ b/make_hmc_dataset.py @@ -0,0 +1,187 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. 
* +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +import os +import mne +import argparse +import pickle +import h5py + +import numpy as np + +from pathlib import Path +from tqdm import tqdm + +standard = {'EEG F4-M1', 'EEG C4-M1', 'EEG O2-M1', 'EEG C3-M2', 'ECG'} + +mapping = { + 'EEG F4-M1': 'F4', + 'EEG C4-M1': 'C4', + 'EEG O2-M1': 'O2', + 'EEG C3-M2': 'C3', + 'ECG': 'I' +} +all_channels = list(standard) + +def match_annotations_and_files_and_split(input_dir): + """ + Lists all .edf files in the directory and matches signal files with annotations. + Additionally, splits into train, val and test according to https://arxiv.org/pdf/2504.19596. + First 100 subjects goes to train, next 25 for validation and next 26 for test. + + Args: + input_dir (pathlib.Path): Path to the directory with recordings. 
+ """ + + files = input_dir.rglob('*.edf') + + # .edf files are signal and _sleepscoring.edf are annotations + signal_files = [] + annotation_files = [] + for file in files: + + file_name = file.parts[-1][:-4] + + if len(file_name.split('_')) == 1: # if there's no underscore, it's signal file + signal_files.append(file) + else: # it's annotation file + annotation_files.append(file) + + # Files are already sorted + train_files, train_ann = signal_files[:100], annotation_files[:100] + val_files, val_ann = signal_files[100:125], annotation_files[100:125] + test_files, test_ann = signal_files[125:], annotation_files[125:] + + return train_files, train_ann, val_files, val_ann, test_files, test_ann + +def dump_pickle(X, y, file_name, split): + """ + Writes data to a pickle file. + + Args: + X (np.array): Data to write to pickle. Shape: (num_samples, num_channels, sample_length) + y (np.array): Label corresponding to each sample. 5-class single label classification. Shape: (num_samples,) + file_name (str): Name to give to a pickle file. + split (str): Either "train", "val" or "test". 
+ """ + + data_dict = {"X": X, "y": y} + dump_path = os.path.join(output_dir, split, f"{file_name}.pkl") + + # Write data to the path + with open(dump_path, "wb") as f: + pickle.dump(data_dict, f) + + print(f"Dumped into {file_name}.pkl pickle!") + +def create_hdf5(output_dir, split): + + folder_path = os.path.join(output_dir, split) + target_file = os.path.join(output_dir, f"{split}.h5") + files = sorted(os.listdir(folder_path)) + + with h5py.File(target_file, 'w') as h5f: + for i, file in enumerate(tqdm(files)): + with open(os.path.join(folder_path, file), 'rb') as f: + + sample = pickle.load(f) + + group = h5f.create_group(f"data_group_{i}") + X_data = np.array(sample["X"]) + group.create_dataset("X", data=X_data) + + y_data = np.array(sample['y']) + group.create_dataset("y", data=y_data) + + +def process_split(signal_files, annotation_files, output_dir, split): + """ + Process one split (train/val/test). + Adds annotations, extracts relevant channels, filters, and epochs based on sleep stage annotations. + + Args: + signal_files (list[Path]): List of paths to the signal .edf files. + annotation_files (list[Path]): List of paths to the annotation .edf files. + output_dir (str): Output directory to write files. + split (str): Either "train", "val" or "test". 
+ """ + # Make directory for the split if it doesn't exist + os.makedirs(os.path.join(output_dir, split), exist_ok=True) + + for signal, annot in zip(signal_files, annotation_files): + + # Get filename for saving + file_name = signal.parts[-1][:-4] + + # Load signal and annotations + raw = mne.io.read_raw_edf(signal, preload=True, verbose=False, include=all_channels) + ann = mne.read_annotations(annot) + raw.set_annotations(ann) + + # Rename channel type to ecg for easier filtering // rename for positions + raw.set_channel_types({'ECG': 'ecg'}) + raw.rename_channels(mapping) + raw.filter(l_freq=0.1, h_freq=75.0, picks='eeg', verbose='ERROR') + raw.filter(l_freq=0.5, h_freq=120.0, picks='ecg', verbose='ERROR') + raw.notch_filter(50.0, verbose="ERROR") + + # Get annotations + events, event_id = mne.events_from_annotations(raw) + event_id = {k: v for k, v in event_id.items() if v not in [1,2]} # event_id markers 1,2 correspond to the light on/off moments + epochs = mne.Epochs(raw, events, event_id, tmin=0, tmax=30, baseline=None, verbose=True) + data = epochs.get_data()[:, :, :-1] # epochs returns 30*sampling_freq + 1 samples so we discard the last one + + # We need to create labels - it's a third column from events + # Just need to discard the 1 and 2 corresponding to lights on/off + # Subtracting 3 to shift to 0-4 values + labels = events[:, 2] + remove_on_off = labels[(labels != 1) & (labels != 2)] + labels_corrected = remove_on_off - 3 + + # Write to pickle + dump_pickle(X=data, y=labels_corrected, file_name=file_name, split=split) + +if __name__ == "__main__": + + parser = argparse.ArgumentParser() + parser.add_argument('--input_dir', type=str, default='#CHANGEME') + parser.add_argument('--output_dir', type=str, default='#CHANGEME') + + args = parser.parse_args() + input_dir = Path(args.input_dir) + output_dir = args.output_dir + + # List all the files in the directory + files = [f for f in input_dir.rglob('*.edf')] + train_files, train_ann, val_files, val_ann, 
test_files, test_ann = match_annotations_and_files_and_split(input_dir) + + # Make output directory if it doesn't exist + os.makedirs(output_dir, exist_ok=True) + + # Process each split + process_split(train_files, train_ann, output_dir, split="train") + process_split(val_files, val_ann, output_dir, split="val") + process_split(test_files, test_ann, output_dir, split="test") + + # Write to HDF5 + create_hdf5(output_dir, split="train") + create_hdf5(output_dir, split="val") + create_hdf5(output_dir, split="test") + + diff --git a/make_mimic_iv_dataset.py b/make_mimic_iv_dataset.py new file mode 100644 index 0000000..506f255 --- /dev/null +++ b/make_mimic_iv_dataset.py @@ -0,0 +1,127 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. 
def dump_data(data_slice, output_dir, data_group_name, file_idx, channel_info):
    """
    Append one batch of processed ECGs as a new group of
    MIMIC-IV-ECG_12_channels_<file_idx>.h5.

    Args:
        data_slice (np.ndarray): Batch of segments, shape (N, channels, samples).
        output_dir (str): Output directory.
        data_group_name (str): Name of the HDF5 group to create.
        file_idx (int): Index used to name the output file.
        channel_info (list[str]): Lead names, stored once as a file attribute.
    """
    output_path = os.path.join(output_dir, f"MIMIC-IV-ECG_12_channels_{file_idx}.h5")
    with h5py.File(output_path, "a") as h5f:
        # Bug fix: the check read attrs key "channel_name" while the attribute
        # is stored under "channel_names", so it was always None and the
        # attribute was rewritten on every flush. Use the same key for both.
        if h5f.attrs.get("channel_names") is None:
            h5f.attrs['channel_names'] = channel_info
        grp = h5f.create_group(data_group_name)
        X_data = np.array(data_slice, dtype=np.float16)
        grp.create_dataset("X", data=X_data, dtype='float16')

def process_single_file(file_paths, output_dir, downsample_fs, split_signal, file_idx, worker_id, batch_size=300):
    """
    Process a list of MIMIC-IV-ECG records and append them, in batches, to one
    HDF5 file. Records are band-pass filtered, min-max normalised to [0, 1],
    optionally segmented into split_signal-second windows, and flushed to disk
    every batch_size records.
    """
    batch_data = []
    batch_idx = 0
    channel_info = None

    def flush_batch():
        # Write the accumulated records as one group and start a new batch.
        nonlocal batch_idx
        if not batch_data:
            return
        concatenated = np.concatenate(batch_data, axis=0)
        dump_data(concatenated, output_dir, f"{worker_id}_{batch_idx}", file_idx, channel_info)
        batch_data.clear()
        batch_idx += 1

    with tqdm.tqdm(total=len(file_paths), desc=f"Worker {worker_id}", position=worker_id) as pbar:
        for file_path in file_paths:
            # Bug fix: `file_path.split('.')[0]` truncated the path at the
            # FIRST dot anywhere in it (e.g. a versioned directory name);
            # os.path.splitext only strips the trailing extension.
            record_path = os.path.splitext(file_path)[0]
            record = wfdb.rdrecord(record_path)

            raw_ecg = np.array(record.p_signal).T  # (samples, leads) -> (leads, samples)
            sampling_rate = record.fs
            channel_info = record.sig_name

            processed_ecg = preprocess_signal(raw_ecg, sampling_rate, 0.5, 120, downsample_fs, None)

            # Skip constant (flat) recordings: min-max would divide by zero.
            p_min = processed_ecg.min()
            p_max = processed_ecg.max()
            if (p_max - p_min) == 0:
                pbar.update(1)
                continue

            processed_ecg = (processed_ecg - p_min) / (p_max - p_min)
            merged_splits = processed_ecg.reshape(1, processed_ecg.shape[0], processed_ecg.shape[1])

            if split_signal is not None:
                signal_splitted = time_segmenting(merged_splits, split_signal, sampling_rate, downsample_fs, None)
                merged_splits = np.concatenate(signal_splitted, axis=0)

            batch_data.append(merged_splits)

            if len(batch_data) >= batch_size:
                flush_batch()

            pbar.update(1)

    flush_batch()  # flush remaining records that didn't fill a full batch

def loading_and_processing_parallel(files, output_dir, downsampling_fs, split_signal, max_sessions_per_file=30000):
    """
    Shard `files` into groups of max_sessions_per_file records and process the
    groups in parallel, one output HDF5 file per group.
    """
    num_workers = 27
    worker_args = []

    for i in range(0, len(files), max_sessions_per_file):
        shard = files[i:i + max_sessions_per_file]
        worker_args.append((shard, output_dir, downsampling_fs, split_signal, i // max_sessions_per_file))

    print(f"Processing {len(worker_args)} groups of files:")
    os.makedirs(output_dir, exist_ok=True)

    # Cleanup: dropped an unused multiprocessing.Manager() context and a stray
    # debug print of max_sessions_per_file.
    with multiprocessing.Pool(num_workers) as pool:
        pool.starmap(process_single_file, [
            (args[0], args[1], args[2], args[3], args[4], worker_id, 300)
            for worker_id, args in enumerate(worker_args)
        ])
If None, no splitting is done.') + + args = parser.parse_args() + + if os.path.exists('/scratch2/msc25h9/mimic-iv-ecg-files.pkl'): + with open('/scratch2/msc25h9/mimic-iv-ecg-files.pkl', 'rb') as f: + files = pickle.load(f) + else: + result = subprocess.run(["find", args.input_dir, "-maxdepth", str(5), "-type", "f", "-name", "*.hea"], stdout=subprocess.PIPE, text=True) + files = result.stdout.splitlines() + + loading_and_processing_parallel(files, args.output_dir, args.downsampling_fs, args.split_signal) \ No newline at end of file diff --git a/make_ptbxl_dataset.py b/make_ptbxl_dataset.py new file mode 100644 index 0000000..b1db63c --- /dev/null +++ b/make_ptbxl_dataset.py @@ -0,0 +1,158 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. 
#*                                                                            *
#* Author: Marija Zelic                                                      *
#* Author: Thorir Mar Ingolfsson                                             *
#*----------------------------------------------------------------------------*
import pandas as pd
import numpy as np

import wfdb
import ast
import argparse
import os
import pickle
import h5py

from process_raw_ecg import preprocess_signal
from itertools import chain
from tqdm import tqdm

# Number of label columns per PTB-XL task setup; the labels occupy the last
# N columns of each row in the MERL-provided csv splits.
SETUP_NUM_CLASSES = {'super_class': 5, 'sub_class': 23, 'form': 19, 'rhythm': 12}


def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
    """
    Lists all the pickle files in the source directory and writes them in the .h5 file.

    Args:
        source_dir (str): Source directory where pickle files are found.
        target_file (str): Path to the target .h5 file.
        finetune (bool): Whether data has label or not.
        group_size (int): Group size in HDF5 saving.
    """
    files = sorted(os.listdir(source_dir))
    data_group = []

    with h5py.File(target_file, 'w') as h5f:
        for i, file in enumerate(tqdm(files)):
            with open(os.path.join(source_dir, file), 'rb') as f:
                sample = pickle.load(f)
                data_group.append(sample)

            # Flush a group every `group_size` samples and once more at the end.
            if (i + 1) % group_size == 0 or i == len(files) - 1:
                grp = h5f.create_group(f"data_group_{i // group_size}")
                X_data = np.array([s['X'] for s in data_group])
                grp.create_dataset("X", data=X_data)

                if finetune:
                    y_data = np.array([s['y'] for s in data_group])
                    grp.create_dataset("y", data=y_data)

                data_group = []


def process_csv_files(args, csv_file, setup, split_type):
    """
    Process csv files and dump into pickle files. The csv files are taken from
    the MERL ICML 2024 repository
    (https://github.com/cheliu-computation/MERL-ICML2024/tree/main/finetune/data_split).

    Args:
        args: Parsed arguments from argparse.
        csv_file (str): Path to the csv file.
        setup (str): One of ['super_class', 'sub_class', 'form', 'rhythm'].
        split_type (str): One of ['train', 'val', 'test'].
    """
    # Make output directory for this setup/split if it doesn't exist
    os.makedirs(os.path.join(args.output_dir, setup, split_type), exist_ok=True)

    print(f"Processing {csv_file}...")
    df = pd.read_csv(csv_file)

    if setup not in SETUP_NUM_CLASSES:
        raise ValueError("Invalid setup provided.")
    n_classes = SETUP_NUM_CLASSES[setup]

    for idx, row in tqdm(df.iterrows(), total=df.shape[0]):

        # Load the signal; we always use the 500 Hz version.
        file_path = os.path.join(args.input_dir, row['filename_hr'])
        signal, _ = wfdb.rdsamp(file_path)

        # Band-pass 0.5-120 Hz and resample to args.downsampling_fs.
        processed_signal = preprocess_signal(signal.T, args.sampling_rate, 0.5, 120, args.downsampling_fs, None)

        # The last n_classes columns of every row are the multi-hot labels.
        y = row.iloc[-n_classes:].values.astype(np.int8)

        data_dict = {"X": processed_signal, "y": y}
        dump_path = os.path.join(args.output_dir, setup, split_type, f'ptb-xl-{idx}.pkl')

        with open(dump_path, "wb") as f:
            pickle.dump(data_dict, f)


def main_splitted(args, setup):
    """
    Main function for processing the PTB-XL dataset and saving as pickle files.
    Uses splits provided by the MERL paper for consistency in comparison.
    Filtering, downsampling and segmenting are done here.

    Args:
        args: Parsed arguments from argparse.
        setup (str): One of ['super_class', 'sub_class', 'form', 'rhythm'].
    """
    train_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_train.csv')
    val_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_val.csv')
    test_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_test.csv')

    # Create output directory for this setup if it doesn't exist
    os.makedirs(os.path.join(args.output_dir, setup), exist_ok=True)

    # Write each csv into pickle files with X and y
    process_csv_files(args, train_csv, setup, split_type='train')
    process_csv_files(args, val_csv, setup, split_type='val')
    process_csv_files(args, test_csv, setup, split_type='test')

    # Finally, write to HDF5
    for td in ['train', 'val', 'test']:
        # BUGFIX: the guard previously checked ".../<td>/" + ".h5" (a file
        # literally named ".h5" inside the split folder), so it never matched
        # and the h5 was always regenerated. Check the actual target path.
        h5_path = os.path.join(args.output_dir, setup, f"{td}.h5")
        if os.path.exists(h5_path):
            print(f"File {td}.h5 already exists!")
        else:
            print(f"Creating file {td}.h5.")
            create_hdf5(os.path.join(args.output_dir, setup, td), h5_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Process PTB-XL dataset files and save as pickle.")
    parser.add_argument('--input_dir', type=str, default='#CHANGEME')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME')
    parser.add_argument('--csv_files_dir', type=str, default='#CHANGEME')
    parser.add_argument('--sampling_rate', type=int, default=500)
    parser.add_argument('--downsampling_fs', type=int, default=256)
    # BUGFIX: the default used to be the int 5, which is not a valid setup and
    # made a default run raise ValueError; default to a valid string instead.
    parser.add_argument('--setup', type=str, default='super_class',
                        help="Either super_class, sub_class, form or rhythm.")

    args = parser.parse_args()
    setup = args.setup

    main_splitted(args, setup=setup)
#* SPDX-License-Identifier: Apache-2.0                                        *
#*                                                                            *
#* Licensed under the Apache License, Version 2.0 (the "License");            *
#* you may not use this file except in compliance with the License.           *
#* You may obtain a copy of the License at                                    *
#*                                                                            *
#* http://www.apache.org/licenses/LICENSE-2.0                                 *
#*                                                                            *
#* Unless required by applicable law or agreed to in writing, software        *
#* distributed under the License is distributed on an "AS IS" BASIS,          *
#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *
#* See the License for the specific language governing permissions and        *
#* limitations under the License.                                             *
#*                                                                            *
#* Author: Marija Zelic                                                      *
#* Author: Thorir Mar Ingolfsson                                             *
#*----------------------------------------------------------------------------*
import wfdb
import numpy as np
import pandas as pd
import scipy.io
import pickle
import argparse
import multiprocessing
import tqdm
import h5py
import ast
import os

from process_raw_ecg import preprocess_signal, time_segmenting
from pathlib import Path
from sklearn.model_selection import train_test_split
from mat73 import loadmat


def dump_data(data_slice, output_dir, data_group_name, file_idx, channel_info, dataset):
    """Append one session's windows to a per-worker HDF5 file.

    Args:
        data_slice (np.ndarray): Windows of shape (n_windows, 2, n_samples) [ECG, PPG].
        output_dir (str or Path): Directory for the output .h5 files.
        data_group_name (str): HDF5 group name (the session name).
        file_idx (int): Shard index so each worker writes its own file.
        channel_info (list[str]): Channel names, stored once as a file attribute.
        dataset (str): Dataset name used in the output filename.
    """
    data_group = data_slice
    output_path = os.path.join(output_dir, f"{dataset}_2_channels_{file_idx}.h5")

    with h5py.File(output_path, "a") as h5f:

        if h5f.attrs.get("channel_names") is None:
            h5f.attrs["channel_names"] = channel_info

        grp = h5f.create_group(data_group_name)
        X_data = np.array(data_group, dtype=np.float16)
        grp.create_dataset("X", data=X_data, dtype='float16')

    print(f"Saved {len(data_slice)} samples to {output_path}.")


def process_single_file(file_paths, output_dir, sampling_rate, upsample_fs, split_signal, dataset, file_idx, worker_id):
    """
    Process a list of PulseDB .mat sessions. Friendly for multiprocessing.

    Each session's raw ECG and PPG are band-pass filtered, upsampled to
    `upsample_fs`, concatenated along the channel axis, optionally segmented
    into `split_signal`-second windows and written to HDF5.
    """
    with tqdm.tqdm(total=len(file_paths), desc=f"Worker {worker_id}", position=worker_id) as pbar:
        for file_path in file_paths:
            channel_info = ['II', 'PPG']
            # Compute before any loading so the fallback error message can use it.
            session_name = file_path.parts[-1][:-4]

            # Handle case if file is corrupted / saved in the old MATLAB format
            try:
                data = loadmat(file_path)
                raw_ecg = np.array(data['Subj_Wins']['ECG_Raw'])
                raw_ppg = np.array(data['Subj_Wins']['PPG_Raw'])

            except Exception as e:
                print(f"Standard loading failed for {session_name}. Falling back to scipy.io.loadmat().")
                try:
                    data = scipy.io.loadmat(file_path)
                    unpacked = data['Subj_Wins'][0, 0]
                    raw_ecg = np.array(unpacked['ECG_Raw']).T
                    raw_ppg = np.array(unpacked['PPG_Raw']).T

                except Exception as scipy_e:
                    # BUGFIX: this used to `return None`, silently abandoning
                    # every remaining file assigned to this worker. Skip only
                    # the corrupted session and keep going.
                    print(f"Error loading {file_path} even with scipy.io. {scipy_e}.")
                    pbar.update(1)
                    continue

            # Preprocess signals (PPG band-limited to 12 Hz) and concatenate.
            processed_ecg = preprocess_signal(raw_ecg, sampling_rate, 0.5, 120, None, upsample_fs)
            processed_ppg = preprocess_signal(raw_ppg, sampling_rate, 0.5, 12, None, upsample_fs)

            # Single-window sessions come back 2-D; lift to (1, 1, n_samples).
            if processed_ecg.ndim < 3 or processed_ppg.ndim < 3:
                processed_ecg = processed_ecg.reshape(1, 1, -1)
                processed_ppg = processed_ppg.reshape(1, 1, -1)

            signals = np.concatenate((processed_ecg, processed_ppg), axis=1)
            merged_splits = signals

            # Time splitting
            if split_signal is not None:
                signal_splitted = time_segmenting(signals, split_signal, sampling_rate, None, upsample_fs)
                merged_splits = np.concatenate(signal_splitted, axis=0)

            dump_data(merged_splits, output_dir, session_name, file_idx, channel_info, dataset)

            pbar.update(1)


def loading_and_processing_parallel(files, output_dir, sampling_rate, upsampling_fs, split_signal, dataset, max_sessions_per_file=300):
    """
    Load and process all .mat files in the input directory in parallel.

    Args:
        files (list): List of .mat filenames to process.
        output_dir (str): Directory to save processed files.
        sampling_rate (int): Original sampling rate of the signals.
        upsampling_fs (int): Desired upsampling frequency.
        split_signal (int or None): Length of segments to split the signals into (in seconds). If None, no splitting is done.
        dataset (str): Dataset name.
        max_sessions_per_file (int): Sessions per output shard / worker task.
    """
    num_workers = 4
    worker_args = []

    # Prepare arguments for parallel processing
    for i in range(0, len(files), max_sessions_per_file):
        worker_args.append((files[i:i + max_sessions_per_file], output_dir, sampling_rate,
                            upsampling_fs, split_signal, dataset, int(i / max_sessions_per_file)))

    print(f"Processing {len(worker_args)} groups of files:")
    os.makedirs(output_dir, exist_ok=True)

    # Use multiprocessing to parallelize the process
    with multiprocessing.Manager() as manager:
        with multiprocessing.Pool(num_workers) as pool:
            pool.starmap(process_single_file,
                         [(args[0], args[1], args[2], args[3], args[4], args[5], args[6], worker_id)
                          for worker_id, args in enumerate(worker_args)])


if __name__ == "__main__":

    parser = argparse.ArgumentParser(description="Process VitalDB .mat files and save as .h5 files.")
    parser.add_argument('--dataset', type=str, default='VitalDB', help='Either VitalDB or MimicDB.')
    parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory containing VitalDB .mat files.')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed files.')
    parser.add_argument('--sampling_rate', type=int, default=125, help='Original sampling rate of the signals.')
    parser.add_argument('--upsampling_fs', type=int, default=256, help='Desired upsampling frequency.')
    parser.add_argument('--split_signal', type=int, default=5,
                        help='Length of segments to split the signals into (in seconds). If None, no splitting is done.')

    args = parser.parse_args()
    dataset = args.dataset
    input_dir = Path(args.input_dir)
    output_dir = Path(args.output_dir)
    sampling_rate = args.sampling_rate
    upsampling_fs = args.upsampling_fs
    split_signal = args.split_signal

    # List all .mat files in the input directory (recursively)
    files = [f for f in input_dir.rglob('*.mat')]
    loading_and_processing_parallel(files, output_dir, sampling_rate, upsampling_fs, split_signal, dataset)
#*                                                                            *
#* Author: Marija Zelic                                                      *
#* Author: Thorir Mar Ingolfsson                                             *
#*----------------------------------------------------------------------------*
import mne
import argparse
import os
import numpy as np
import csv
import datetime
import pickle
import h5py
from tqdm import tqdm

# Code taken from SEED-VII instructions
sampling_freq = 256

# Emotion label IDs used throughout this script:
# neutral = 0, happy = 1, sad = 2, disgust = 3, anger = 4, fear = 5, surprise = 6

# Per-session trial labels (20 trials per session, 4 sessions).
per_session_labels = {
    '1': [1, 0, 3, 2, 4, 4, 2, 3, 0, 1, 1, 0, 3, 2, 4, 4, 2, 3, 0, 1],
    '2': [4, 2, 5, 0, 6, 6, 0, 5, 2, 4, 4, 2, 5, 0, 6, 6, 0, 5, 2, 4],
    '3': [1, 6, 3, 5, 4, 4, 5, 3, 6, 1, 1, 6, 3, 5, 4, 4, 5, 3, 6, 1],
    '4': [3, 2, 5, 6, 1, 1, 6, 5, 2, 3, 3, 2, 5, 6, 1, 1, 6, 5, 2, 3]
}

channels = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2', 'ECG']


def process_and_split(input_dir, output_dir, file_dict):
    """
    Process sessions and place them in the correct split (train/val/test),
    following the splits from PhysioOmni:
    20 subjects - 4 sessions each - 20 trials per session.
    The first 10 trials of each session go to train, the next 5 to val and
    the last 5 to test.

    Args:
        input_dir (str): Directory with the raw .cnt recordings.
        output_dir (str): Directory that receives train/val/test pickle files.
        file_dict (dict): Mapping subject-id -> list of that subject's .cnt files.
    """
    collect_num = 0
    for key, value in file_dict.items():
        for raw_file in value:
            print("Starting with path:", raw_file)
            raw_path = os.path.join(input_dir, raw_file)
            raw = mne.io.read_raw_cnt(raw_path, ecg=['ECG'])

            # Drop reference / ocular channels not used downstream.
            raw.drop_channels(['M1', 'M2', 'HEO', 'VEO'])
            raw.load_data()

            # Band-pass EEG (0.1-75 Hz) and ECG (0.5-120 Hz), then 50 Hz notch.
            raw.filter(l_freq=0.1, h_freq=75.0, picks='eeg', verbose='ERROR')
            raw.filter(l_freq=0.5, h_freq=120.0, picks='ecg', verbose='ERROR')
            raw.notch_filter(50.0, verbose='ERROR')

            if raw.info['sfreq'] != sampling_freq:
                raw.resample(sampling_freq)

            # Trial onsets/offsets come from the annotation triggers.
            trigger, _ = mne.events_from_annotations(raw)
            data, times = raw.get_data(return_times=True)

            t = trigger[:, 0]

            # The dataset authors flag two files whose embedded triggers are
            # wrong; rebuild trigger sample indices from the timestamp CSVs.
            # NOTE(review): the '#CHANGEME' CSV paths must be filled in before use.
            if raw_file == "14_20221015_1.cnt":
                t = []
                start = datetime.datetime.strptime('14:25:34', '%H:%M:%S')
                with open('#CHANGEME') as f:
                    trigger = csv.reader(f)
                    for row in trigger:
                        end = datetime.datetime.strptime(row[1].split(' ')[-1], '%H:%M:%S.%f')
                        time_diff = end.timestamp() - start.timestamp()
                        t.append(int(round(time_diff * sampling_freq)))
            elif raw_file == "9_20221111_3.cnt":
                t = []
                start = datetime.datetime.strptime('14:01:27', '%H:%M:%S')
                with open(os.path.join('#CHANGEME')) as f:
                    trigger = csv.reader(f)
                    for row in trigger:
                        end = datetime.datetime.strptime(row[1].split(' ')[-1], '%H:%M:%S.%f')
                        time_diff = end.timestamp() - start.timestamp()
                        t.append(int(round(time_diff * sampling_freq)))

            # Session index is the digit before ".cnt" in "<subj>_<date>_<n>.cnt".
            session_idx = int(raw_file[-5])
            for i in range(20):
                # Data between the start/end trigger of trial i.
                preprocessed_clip = data[:, t[2 * i]:t[2 * i + 1]]
                num = preprocessed_clip.shape[1] // sampling_freq
                collect_num += num

                # Keep an integer number of seconds and reshape so that each
                # one-second chunk becomes one sample: (num, channels, freq).
                signal = preprocessed_clip[:, :num * sampling_freq]
                split_signal = signal.reshape(signal.shape[0], num, sampling_freq).transpose(1, 0, 2)

                for idx in range(num):
                    data_dict = {"X": split_signal[idx, :, :], 'y': per_session_labels[str(session_idx)][i]}
                    dump_name = f"{key}_seed_{session_idx}_{i}_{idx}.pkl"

                    # Trials 0-9 -> train, 10-14 -> val, 15-19 -> test.
                    if i < 10:
                        dump_path = os.path.join(output_dir, "train", dump_name)
                    elif i < 15:
                        dump_path = os.path.join(output_dir, "val", dump_name)
                    else:
                        dump_path = os.path.join(output_dir, "test", dump_name)

                    with open(dump_path, "wb") as f:
                        pickle.dump(data_dict, f)
            print("Current sample count:", collect_num)


def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
    """
    Lists all the pickle files in the source directory and writes them in the .h5 file.

    Args:
        source_dir (str): Source directory where pickle files are found.
        target_file (str): Path to the target .h5 file.
        finetune (bool): Whether data has label or not.
        group_size (int): Group size in HDF5 saving.
    """
    files = sorted(os.listdir(source_dir))
    data_group = []

    with h5py.File(target_file, 'w') as h5f:
        for i, file in enumerate(tqdm(files)):
            with open(os.path.join(source_dir, file), 'rb') as f:
                sample = pickle.load(f)
                data_group.append(sample)

            # Flush a group every `group_size` samples and once more at the end.
            if (i + 1) % group_size == 0 or i == len(files) - 1:
                grp = h5f.create_group(f"data_group_{i // group_size}")
                X_data = np.array([s['X'] for s in data_group])
                grp.create_dataset("X", data=X_data)

                if finetune:
                    y_data = np.array([s['y'] for s in data_group])
                    grp.create_dataset("y", data=y_data)

                data_group = []


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='#CHANGEME')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME')

    # Parse arguments
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    # Group recordings per subject; file[:-15] strips "_YYYYMMDD_N.cnt"
    # (15 chars), leaving the subject id as the key.
    raw_files = os.listdir(input_dir)
    file_dict = {}

    for file in raw_files:
        if file[:-15] not in file_dict.keys():
            file_dict[file[:-15]] = [file]
        else:
            file_dict[file[:-15]].append(file)

    os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "val"), exist_ok=True)
    os.makedirs(os.path.join(output_dir, "test"), exist_ok=True)

    process_and_split(input_dir, output_dir, file_dict)

    # Finally, write to HDF5
    to_do = ['train', 'val', 'test']
    for td in to_do:
        if os.path.exists(output_dir + '/' + td + '.h5'):
            print(f"File {td}.h5 already exists!")
        else:
            print(f"Creating file {td}.h5.")
            create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5")
#*----------------------------------------------------------------------------*
#* Copyright (C) 2026 ETH Zurich, Switzerland                                 *
#* SPDX-License-Identifier: Apache-2.0                                        *
#*                                                                            *
#* Licensed under the Apache License, Version 2.0 (the "License");            *
#* you may not use this file except in compliance with the License.           *
#* You may obtain a copy of the License at                                    *
#*                                                                            *
#* http://www.apache.org/licenses/LICENSE-2.0                                 *
#*                                                                            *
#* Unless required by applicable law or agreed to in writing, software        *
#* distributed under the License is distributed on an "AS IS" BASIS,          *
#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *
#* See the License for the specific language governing permissions and       *
#* limitations under the License.                                             *
#*                                                                            *
#* Author: Marija Zelic                                                      *
#* Author: Thorir Mar Ingolfsson                                             *
#*----------------------------------------------------------------------------*
import mne
import argparse
import tqdm
import h5py
import os

import numpy as np

from pathlib import Path

standard = {
    'EEG FP1', 'EEG F3', 'EEG C3', 'EEG P3', 'EEG O1', 'EEG F7', 'EEG T3', 'EEG T5', 'EEG FC1', 'EEG FC5',
    'EEG CP1', 'EEG CP5', 'EEG F9', 'EEG FZ', 'EEG CZ', 'EEG PZ', 'EEG FP2', 'EEG F4', 'EEG C4', 'EEG P4', 'EEG O2',
    'EEG F8', 'EEG T4', 'EEG T6', 'EEG FC2', 'EEG FC6', 'EEG CP2', 'EEG CP6', 'EEG F10'
}

all_channels = list(standard)
sampling_freq = 256  # target sampling frequency
segment_length = 5   # segment length in seconds


def create_hdf5(sliced_data, output_dir, session_name):
    """
    Append one session's segmented data to the shared HDF5 file.

    Args:
        sliced_data (np.ndarray): Segmented data of shape (n_intervals, n_channels, interval_size).
        output_dir (str or Path): Directory to save the HDF5 file.
        session_name (str): Name of the session/file, used as the HDF5 group name.
    """
    os.makedirs(output_dir, exist_ok=True)
    target_path = os.path.join(output_dir, f"SIENA_29_channels.h5")

    with h5py.File(target_path, "a") as hdf5_file:
        group = hdf5_file.create_group(session_name)
        X_data = np.array(sliced_data, dtype=np.float16)
        group.create_dataset("X", data=X_data, dtype='float16')


def process_and_save_files_to_hdf5(file_paths, output_dir):
    """
    Process EDF files, extract standard EEG channels, filter, downsample, segment and save to HDF5.

    Args:
        file_paths (list of Path): List of paths to EDF files.
        output_dir (str or Path): Directory to save processed HDF5 files.
    """
    for file_path in tqdm.tqdm(file_paths):

        print(f"Processing file: {file_path.name}")
        session_name = file_path.parts[-1][:-4]
        # Load raw EDF file
        raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)

        # Map all channel names to upper-case for consistency, keep the 29
        # standard channels, then strip the "EEG " prefix.
        upper_mapping = {ch: ch.upper() for ch in raw.ch_names}
        raw.rename_channels(upper_mapping)
        raw.pick(all_channels)

        mapping = {ch: ch.split(' ')[1] for ch in raw.ch_names}
        raw.rename_channels(mapping)

        # Filtering (0.1-75 Hz bandpass and 50 Hz notch)
        raw.filter(l_freq=0.1, h_freq=75.0, verbose="ERROR")
        raw.notch_filter(50, verbose="ERROR")

        # Downsample to the target frequency
        if raw.info['sfreq'] != sampling_freq:
            raw.resample(sampling_freq)
        data = raw.get_data(units='uV')

        # Check for NaN and Infs (just for notification)
        if np.isnan(data).any() or np.isinf(data).any():
            print(f"Warning: NaN or Inf values found in file {file_path.name}")

        n_channels, n_times = data.shape
        print(f"Data shape (channels x timepoints): {data.shape}")
        interval_size = segment_length * sampling_freq
        num_intervals = n_times // interval_size

        # BUGFIX: reshaping (n_channels, num_intervals * interval_size)
        # directly into (num_intervals, n_channels, interval_size) scrambles
        # channels and time (C-order reshape). Reshape per-channel first and
        # transpose, as done for SEED-VII, so each segment keeps contiguous
        # time samples per channel.
        trimmed = data[:, :num_intervals * interval_size]
        new_sliced_data = trimmed.reshape(n_channels, num_intervals, interval_size).transpose(1, 0, 2)

        create_hdf5(new_sliced_data, output_dir, session_name)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Input directory containing raw EDF files')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME')

    # Parse arguments
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    # Find all EDF files in the input directory
    input_dir = Path(input_dir)
    file_paths = [f for f in input_dir.rglob('*.edf')]
    print(f"Found {len(file_paths)} EDF files.")

    process_and_save_files_to_hdf5(file_paths, output_dir)
#*                                                                            *
#* Author: Marija Zelic                                                      *
#* Author: Thorir Mar Ingolfsson                                             *
#*----------------------------------------------------------------------------*

import pickle
import argparse
import h5py
import os

import numpy as np

from process_raw_ecg import preprocess_signal
from tqdm import tqdm


def binary_classification(ecg_process, ppg_process, labels, subject, split_type, output_dir):
    """
    Label IDs:
    - 0 = not defined/transient
    - 1 = baseline
    - 2 = stress
    - 3 = amusement
    - 4 = meditation
    - 5/6/7 = should be ignored in this dataset

    Binary classification: combined baseline and amusement sessions form the
    non-stress class, while the stress class is the original stress session.

    Args:
        ecg_process (np.ndarray): Preprocessed ECG, shape (channels, samples) at 256 Hz.
        ppg_process (np.ndarray): Preprocessed PPG, shape (channels, samples) at 256 Hz.
        labels (np.ndarray): Per-sample labels at the original 700 Hz rate.
        subject (str): Subject identifier (used in output filenames).
        split_type (str): Split subdirectory to write into.
        output_dir (str): Root output directory (previously read from a module
            global defined only under __main__; now passed explicitly).
    """
    # Each segment is 1 min long; labels run at 700 Hz, data at 256 Hz.
    subseq_size_label = 700 * 60
    subseq_size_data = 256 * 60

    # Extract unique labels and their first occurrence points
    uniques, uniques_index = np.unique(labels, return_index=True)
    counter_subject = 0
    for unique, unique_startidx in zip(uniques, uniques_index):
        if unique not in [1, 2, 3]:
            continue

        while True:
            if unique_startidx // subseq_size_label * subseq_size_data > ppg_process.shape[1]:
                break
            # Find where this label run ends.
            for next_idx in range(unique_startidx, len(labels)):
                if unique != labels[next_idx]:
                    break

            totalsubseqs = (next_idx - unique_startidx) // subseq_size_label
            startidx = unique_startidx // subseq_size_label * subseq_size_data
            data_temp_60sec_ppg = ppg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]
            data_temp_60sec_ecg = ecg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]

            if startidx + totalsubseqs * subseq_size_data > ppg_process.shape[1]:
                # BUGFIX: the truncation used shape[0] (channel count) and
                # sliced axis 0, which zeroed out every tail segment. Count
                # whole minutes along the time axis (axis 1) instead.
                totalsubseqs = data_temp_60sec_ppg.shape[1] // subseq_size_data
                data_temp_60sec_ppg = data_temp_60sec_ppg[:, :totalsubseqs * subseq_size_data]
                data_temp_60sec_ecg = data_temp_60sec_ecg[:, :totalsubseqs * subseq_size_data]

            if totalsubseqs == 0:
                break

            # (channels, n*subseq) -> (n, channels, subseq)
            data_temp_60sec_ppg = np.stack(np.split(data_temp_60sec_ppg, totalsubseqs, 1), 0)
            data_temp_60sec_ecg = np.stack(np.split(data_temp_60sec_ecg, totalsubseqs, 1), 0)

            # Concat both signals along the channel axis
            total_concat = np.concatenate([data_temp_60sec_ecg, data_temp_60sec_ppg], axis=1)
            # Amusement (3) is merged into the non-stress class (1).
            unique_temp = 1 if unique == 3 else unique
            label_temp_60sec = np.repeat(unique_temp, totalsubseqs)

            print(f"{subject}:", total_concat.shape, label_temp_60sec)

            # Write each sample piece to pickle
            for idx in range(len(label_temp_60sec)):
                data_dict = {"X": total_concat[idx], "y": label_temp_60sec[idx]}
                dump_path = os.path.join(output_dir, split_type, f"wesad_{subject}_{counter_subject}.pkl")
                counter_subject += 1

                with open(dump_path, 'wb') as f:
                    pickle.dump(data_dict, f)

            if unique != 4:
                break


def multiclass_classification(ecg_process, ppg_process, labels, subject, split_type, output_dir):
    """
    Multiclass classification: Stress, Baseline, Amusement and Meditation.

    Meditation (label 4) occurs twice in the protocol; after the first run is
    consumed, the scan restarts from the second occurrence (guarded by `flag`
    so it happens at most once).

    Args: same as binary_classification, but labels 1-4 are all kept as-is.
    """
    # Each segment is 1 min long; labels run at 700 Hz, data at 256 Hz.
    subseq_size_label = 700 * 60
    subseq_size_data = 256 * 60

    # Extract unique labels and their first occurrence points
    uniques, uniques_index = np.unique(labels, return_index=True)
    counter_subject = 0
    for unique, unique_startidx in zip(uniques, uniques_index):

        flag = False
        if unique not in [1, 2, 3, 4]:
            continue

        while True:
            if unique_startidx // subseq_size_label * subseq_size_data > ppg_process.shape[1]:
                break
            for next_idx in range(unique_startidx, len(labels)):
                if unique != labels[next_idx]:
                    break

            totalsubseqs = (next_idx - unique_startidx) // subseq_size_label
            startidx = unique_startidx // subseq_size_label * subseq_size_data
            data_temp_60sec_ppg = ppg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]
            data_temp_60sec_ecg = ecg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]

            if startidx + totalsubseqs * subseq_size_data > ppg_process.shape[1]:
                # BUGFIX: same axis mix-up as in binary_classification —
                # truncate along the time axis, not the channel axis.
                totalsubseqs = data_temp_60sec_ppg.shape[1] // subseq_size_data
                data_temp_60sec_ppg = data_temp_60sec_ppg[:, :totalsubseqs * subseq_size_data]
                data_temp_60sec_ecg = data_temp_60sec_ecg[:, :totalsubseqs * subseq_size_data]

            if totalsubseqs == 0:
                break

            data_temp_60sec_ppg = np.stack(np.split(data_temp_60sec_ppg, totalsubseqs, 1), 0)
            data_temp_60sec_ecg = np.stack(np.split(data_temp_60sec_ecg, totalsubseqs, 1), 0)

            # Concat both signals along the channel axis
            total_concat = np.concatenate([data_temp_60sec_ecg, data_temp_60sec_ppg], axis=1)
            label_temp_60sec = np.repeat(unique, totalsubseqs)

            print(f"{subject}:", total_concat.shape, label_temp_60sec)
            # Write each sample piece to pickle
            for idx in range(len(label_temp_60sec)):
                data_dict = {"X": total_concat[idx], "y": label_temp_60sec[idx]}
                dump_path = os.path.join(output_dir, split_type, f"wesad_{subject}_{counter_subject}.pkl")
                counter_subject += 1

                with open(dump_path, 'wb') as f:
                    pickle.dump(data_dict, f)

            if unique != 4:
                break
            else:
                # Meditation appears twice: jump to its second occurrence once.
                if flag:
                    break
                flag = True
                new_label = labels[next_idx:]
                uniques_temp, uniques_indedata_temp = np.unique(new_label, return_index=True)
                try:
                    unique_startidx = uniques_indedata_temp[np.where(uniques_temp == 4)][0] + next_idx
                except IndexError:
                    break


def process_split(input_dir, output_dir, subjects, split_type, classification_type):
    """
    Process PPG and ECG data for a list of subjects, depending on the
    classification type of the task, and dump per-minute pickle samples.

    Args:
        input_dir (str): Directory with the raw WESAD subject folders.
        output_dir (str): Root output directory for the split folders.
        subjects (list[str]): Subject folder names to process.
        split_type (str): Name of the split subdirectory to create/fill.
        classification_type (str): 'binary' or anything else for multiclass.
    """
    # Create directory for the split
    os.makedirs(os.path.join(output_dir, split_type), exist_ok=True)

    # Iterate over all subjects
    for subject in subjects:
        file_path = os.path.join(input_dir, subject, f"{subject}.pkl")

        with open(file_path, 'rb') as f:
            data = pickle.load(f, encoding='latin1')

        # Extract the channels and labels
        ecg = data['signal']['chest']['ECG'].T
        ppg = data['signal']['wrist']['BVP'].T
        labels = data['label']

        # Chest ECG is recorded at 700 Hz (downsampled to 256 Hz), wrist PPG
        # at 64 Hz (upsampled to 256 Hz).
        ecg_process = preprocess_signal(ecg, fs=700, low=0.5, high=120.0, downsample_fs=256, upsample_fs=None)
        ppg_process = preprocess_signal(ppg, fs=64, low=0.5, high=8.0, downsample_fs=None, upsample_fs=256)

        # Depending on the classification_type prepare the data
        if classification_type == "binary":
            binary_classification(ecg_process, ppg_process, labels, subject, split_type, output_dir)
        else:
            multiclass_classification(ecg_process, ppg_process, labels, subject, split_type, output_dir)


def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
    """
    Lists all the pickle files in the source directory and writes them in the .h5 file.

    Args:
        source_dir (str): Source directory where pickle files are found.
        target_file (str): Path to the target .h5 file.
        finetune (bool): Whether data has label or not.
        group_size (int): Group size in HDF5 saving.
    """
    files = sorted(os.listdir(source_dir))
    data_group = []

    with h5py.File(target_file, 'w') as h5f:
        for i, file in enumerate(tqdm(files)):
            with open(os.path.join(source_dir, file), 'rb') as f:
                sample = pickle.load(f)
                data_group.append(sample)

            # Flush a group every `group_size` samples and once more at the end.
            if (i + 1) % group_size == 0 or i == len(files) - 1:
                grp = h5f.create_group(f"data_group_{i // group_size}")
                X_data = np.array([s['X'] for s in data_group])
                grp.create_dataset("X", data=X_data)

                if finetune:
                    y_data = np.array([s['y'] for s in data_group])
                    grp.create_dataset("y", data=y_data)

                data_group = []


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Input directory containing raw files.')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME')

    # Parse arguments
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    # Extract all the folders related to the subjects
    folders = os.listdir(input_dir)
    folders.sort()

    # We need only ones that have S in the name
    subjects = []
    for name in folders:
        if name[0] != "S":
            continue
        subjects.append(name)

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Taking pre-determined train/val/test splits from Pulse-PPG (binary classification)
    process_split(input_dir, output_dir, subjects[:11], split_type='train_binary', classification_type='binary')
    process_split(input_dir, output_dir, subjects[11:13], split_type='val_binary', classification_type='binary')
    process_split(input_dir, output_dir, subjects[13:15], split_type='test_binary', classification_type='binary')

    process_split(input_dir, output_dir, subjects[:11], split_type='train_multiclass', classification_type='multiclass')
    process_split(input_dir, output_dir, subjects[11:13], split_type='val_multiclass', classification_type='multiclass')
    process_split(input_dir, output_dir, subjects[13:15], split_type='test_multiclass', classification_type='multiclass')

    # Finally, write to HDF5
    to_do = ['train_binary', 'val_binary', 'test_binary', 'train_multiclass', 'val_multiclass', 'test_multiclass']
    for td in to_do:
        if os.path.exists(output_dir + '/' + td + '.h5'):
            print(f"File {td}.h5 already exists!")
        else:
            print(f"Creating file {td}.h5.")
            create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5")
* +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +defaults: + - dataset_types + +data_module: + _target_: 'data_module.multiloader_data_module.VaryingChannelsDataModule' + cfg: + num_workers: ${num_workers} + batch_size: ${batch_size} + train_val_split_ratio: 0.8 + datasets: + TUEG_20_channels_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 20 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_20_channels_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 20 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_20_channels_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 20 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_11: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_12: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: 
${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_13: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_14: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_16: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_18: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_19: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_20: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: 
${dataset_types.tueg.sensor_type} + TUEG_22_channels_23: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_24: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_31: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_4: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_5: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_15: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_21: 
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_22: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_25: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_26: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_27: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_28: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_20_channels_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 20 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_29: + _target_: 
'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_30: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_6: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_7: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_8: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_9: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + TUEG_22_channels_10: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' 
+ hdf5_file: "#CHANGEME" + num_channels: 22 + TUEG_22_channels_17: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "${env:SCRATCH_PATH}/TUEG/TUEG_22_channels_17.h5" + num_channels: 22 + channels: ${dataset_types.tueg.channels} + location_fn: ${dataset_types.tueg.location_fn} + sensor_type: ${dataset_types.tueg.sensor_type} + VitalDB_2_channels_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_4: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_5: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: 
${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_6: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_7: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_8: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + VitalDB_2_channels_9: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + 
location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_4: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_5: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MimicDB_2_channels_6: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 2 + channels: ${dataset_types.pulsedb.channels} + location_fn: ${dataset_types.pulsedb.location_fn} + sensor_type: ${dataset_types.pulsedb.sensor_type} + MIMIC-IV-ECG_12_channels_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: 
${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_4: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_5: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_6: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_7: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_8: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_9: + _target_: 
'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_10: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_11: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_12: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_13: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_14: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_15: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: 
${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_16: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_17: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_18: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_19: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_20: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_21: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_22: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + 
location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_23: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_24: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_25: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + MIMIC-IV-ECG_12_channels_26: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.mimicIV.channels} + location_fn: ${dataset_types.mimicIV.location_fn} + sensor_type: ${dataset_types.mimicIV.sensor_type} + CODE15_0: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_1: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_2: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} 
+ location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_3: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_4: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_5: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_6: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_7: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_8: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_9: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: 
${dataset_types.code15.sensor_type} + CODE15_10: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_11: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_12: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_13: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_14: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_15: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_16: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + CODE15_17: + _target_: 
'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 12 + channels: ${dataset_types.code15.channels} + location_fn: ${dataset_types.code15.location_fn} + sensor_type: ${dataset_types.code15.sensor_type} + SIENA_29_channels: + _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset' + hdf5_file: "#CHANGEME" + num_channels: 29 + channels: ${dataset_types.siena.channels} + location_fn: ${dataset_types.siena.location_fn} + sensor_type: ${dataset_types.siena.sensor_type} \ No newline at end of file diff --git a/pretrain_task_PanLUNA.py b/pretrain_task_PanLUNA.py new file mode 100644 index 0000000..5d2944e --- /dev/null +++ b/pretrain_task_PanLUNA.py @@ -0,0 +1,273 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. 
#*----------------------------------------------------------------------------*
#* Copyright (C) 2026 ETH Zurich, Switzerland                                 *
#* SPDX-License-Identifier: Apache-2.0                                        *
#*                                                                            *
#* Licensed under the Apache License, Version 2.0 (the "License");            *
#* you may not use this file except in compliance with the License.           *
#* You may obtain a copy of the License at                                    *
#*                                                                            *
#* http://www.apache.org/licenses/LICENSE-2.0                                 *
#*                                                                            *
#* Unless required by applicable law or agreed to in writing, software        *
#* distributed under the License is distributed on an "AS IS" BASIS,          *
#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *
#* See the License for the specific language governing permissions and        *
#* limitations under the License.                                             *
#*                                                                            *
#* Author: Marija Zelic                                                       *
#* Author: Thorir Mar Ingolfsson                                              *
#*----------------------------------------------------------------------------*

import torch
import pytorch_lightning as pl
import hydra
import wandb
import torch_optimizer as torch_optim
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from criterion.query_specialization_criterion import QuerySpecializationCriterion


class ChannelWiseNormalize:
    """Per-channel z-score normalization over the time axis of a (B, C, T) tensor."""

    def __init__(self, eps=1e-8):
        # eps guards against division by zero on constant (flat) channels.
        self.eps = eps

    def __call__(self, tensor):
        with torch.no_grad():
            # tensor: (B, C, T) — statistics are computed along the time dimension.
            mean = tensor.mean(dim=2, keepdim=True)
            std = tensor.std(dim=2, keepdim=True)
            return (tensor - mean) / (std + self.eps)


class MaskTask(pl.LightningModule):
    """
    PyTorch Lightning module for self-supervised masked-reconstruction pretraining.

    A block-rectangular mask is applied over the (channels x time) input; the model
    reconstructs the signal and the loss is a weighted sum of the reconstruction
    error on masked and unmasked regions, plus an optional query-specialization
    regularizer computed from the model's attention scores.

    Args:
        hparams (DictConfig): Parameters and configurations loaded via Hydra.
    """

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        # Build the regularizer only when it is configured. The steps guard on it
        # being enabled, but the original constructor crashed on a None config
        # because it unconditionally unpacked it with **.
        if self.hparams.query_specialization_criterion is not None:
            self.query_specialization_criterion = QuerySpecializationCriterion(
                **self.hparams.query_specialization_criterion
            )
        else:
            self.query_specialization_criterion = None
        self.patch_size = self.hparams.masking.patch_size  # (patch_H, patch_W)
        self.masking_ratio = self.hparams.masking.masking_ratio
        self.unmasked_loss_coeff = self.hparams.masking.unmasked_loss_coeff
        # Enable channel-wise input normalization if specified in parameters.
        if self.hparams.input_normalization is not None and self.hparams.input_normalization.normalize:
            self.normalize = True
            self.normalize_fct = ChannelWiseNormalize()
        else:
            self.normalize = False

        # One "plot once per validation epoch" flag per dataset, keyed by channel
        # count ('20' is deliberately disabled, as in the original configuration).
        self.plot_batches_flags = {'12': True, '22': True, '2': True, '20': False, '29': True, '64': True}

    def generate_mask(self, batch_size, C, T):
        """
        Generate a boolean mask for block-wise rectangular masking.

        Args:
            batch_size (int): Batch size.
            C (int): Number of channels (height).
            T (int): Temporal length (width).

        Returns:
            torch.BoolTensor: Boolean mask of shape (batch_size, C, T),
            with True in the masked regions. The same mask is shared by all
            samples in the batch.
        """
        patch_H, patch_W = self.patch_size
        masking_ratio = self.masking_ratio

        # Total number of full patch rectangles that fit in the (C, T) grid.
        num_rectangles = (C // patch_H) * (T // patch_W)
        num_to_mask = int(num_rectangles * masking_ratio)

        row_indices = torch.arange(0, C, patch_H)
        col_indices = torch.arange(0, T, patch_W)
        rectangles = [(i, j) for i in row_indices for j in col_indices]

        # Randomly select which rectangles to mask.
        selected_indices = torch.randperm(num_rectangles)[:num_to_mask]

        mask = torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device)

        # Set mask to True in the selected rectangular regions.
        for idx in selected_indices:
            r, c = rectangles[idx]
            mask[:, r:r + patch_H, c:c + patch_W] = True

        return mask

    def training_step(self, batch, batch_idx):
        """
        Training step: validate the input, apply mask, normalize and compute loss.

        Args:
            batch (dict): Batch with 'input', 'channel_locations', 'sensor_type'
                and optionally 'channel_names'.
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value.
        """
        X = batch["input"]
        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]
        mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])

        if self.normalize:
            X = self.normalize_fct(X)

        # Validate the input BEFORE the forward pass (the original checked it
        # afterwards, wasting a forward on corrupted data) and report the real
        # global step instead of a hard-coded "step 0".
        if torch.isnan(X).any() or torch.isinf(X).any():
            print("!!! Input X_ORIGINAL contains NaN or Inf at step", self.global_step, "!!!")
            raise RuntimeError("Input data is corrupted.")

        # Pass masked input through the model to get reconstruction and embeddings.
        x_reconstructed, x_original, attention_scores = self.model(X, mask, channel_locations, sensor_type, channel_names)
        if torch.isnan(x_reconstructed).any():
            print("!!! Model output X_RECONSTRUCTED contains NaN at step", self.global_step, "!!!")
            raise ValueError("NaN detected in model output.")

        # Reconstruction loss on masked patches, plus a down-weighted term on the
        # unmasked (visible) region.
        masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask, self.patch_size[1])
        loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss
        if self.query_specialization_criterion is not None:
            query_specialization_loss = self.query_specialization_criterion(attention_scores)
            loss += query_specialization_loss
            self.log('query_specialization_loss', query_specialization_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        # NOTE(review): 'train_loss' intentionally logs only the masked term,
        # not the total optimized loss — kept as in the original.
        self.log('train_loss', masked_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: apply mask, normalize, compute loss and log signal plots.

        Args:
            batch (dict): Batch with 'input', 'channel_locations', 'sensor_type'
                and optionally 'channel_names'.
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value.
        """
        X = batch["input"]
        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]
        mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])

        if self.normalize:
            X = self.normalize_fct(X)

        # FIX: pass sensor_type exactly as in training_step. It was previously
        # omitted, silently shifting channel_names into the sensor_type slot.
        x_reconstructed, x_original, attention_scores = self.model(X, mask, channel_locations, sensor_type, channel_names)

        masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask, self.patch_size[1])
        loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss

        if self.query_specialization_criterion is not None:
            query_specialization_loss = self.query_specialization_criterion(attention_scores)
            loss += query_specialization_loss
            self.log('query_specialization_loss', query_specialization_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True)

        # Fixed sample indices for logging, filtered so small batches don't
        # index out of range.
        plot_indices = [i for i in (6, 16, 30) if i < X.shape[0]]
        # The channel count identifies the dataset (see plot_batches_flags keys).
        dataset_id = str(sensor_type.shape[1])
        # Log signals with mask only for the first matching validation batch;
        # .get() avoids a KeyError for channel counts not present in the flags.
        if self.plot_batches_flags.get(dataset_id, False) and plot_indices:
            self.log_signals_with_mask(
                x_original.float(),
                x_reconstructed.float(),
                f"Reconstruction {dataset_id}",
                mask,
                batch_indices=plot_indices,
                indice_batch=batch_idx
            )
            self.plot_batches_flags[dataset_id] = False

        return loss

    def on_validation_epoch_end(self):
        # Re-arm the per-dataset plotting flags — assumes validation loaders do
        # not shuffle, so the same batches are plotted each epoch.
        self.plot_batches_flags = {'12': True, '22': True, '2': True, '20': False, '29': True, '64': True}

    def configure_optimizers(self):
        """
        Configure optimizer and scheduler based on hyperparameters.

        Returns:
            dict: Dictionary with optimizer and scheduler for PyTorch Lightning.

        Raises:
            NotImplementedError: If the configured optimizer name is unknown.
        """
        # NOTE(review): weight decay is hard-coded for Adam (0.01) and absent for
        # AdamW — kept as in the original to preserve training behavior.
        if self.hparams.optimizer.optim == "SGD":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.optimizer.lr, momentum=0.9)
        elif self.hparams.optimizer.optim == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.optimizer.lr, weight_decay=0.01)
        elif self.hparams.optimizer.optim == 'AdamW':
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.optimizer.lr)
        elif self.hparams.optimizer.optim == 'LAMB':
            optimizer = torch_optim.Lamb(self.model.parameters(), lr=self.hparams.optimizer.lr)
        else:
            raise NotImplementedError("No valid optim name")

        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer, total_training_opt_steps=self.trainer.estimated_stepping_batches)
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        # timm-style schedulers are stepped by update count, not by epoch.
        scheduler.step_update(num_updates=self.global_step)

    def log_signals_with_mask(self, original, reconstructed, title, mask=None, batch_indices=None, indice_batch=None):
        """
        Log original and reconstructed signals, highlighting masked regions.

        Channel index 1 of each selected sample is plotted (assumes at least two
        channels — TODO confirm for single-channel datasets).

        Args:
            original (torch.Tensor): Original signals of shape (B, C, T).
            reconstructed (torch.Tensor): Signals reconstructed by the model.
            title (str): Figure title prefix used in the logger.
            mask (torch.BoolTensor, optional): Applied mask of shape (B, C, T).
            batch_indices (list[int], optional): Batch indices to log.
            indice_batch (int, optional): Current batch index (kept for interface
                compatibility; not used).
        """
        patch_H, patch_W = self.patch_size
        batch_size, C, T = original.shape
        if batch_indices is None:
            batch_indices = []

        for batch_idx in batch_indices:
            original_signal = original[batch_idx]
            reconstructed_signal = reconstructed[batch_idx]

            fig, ax = plt.subplots(1, 1, figsize=(15, 6))

            # Visualize a single representative channel (index 1).
            original_c = original_signal[1]
            reconstructed_c = reconstructed_signal[1]

            ax.plot(original_c.cpu().numpy(), label='Original Channel 1', color='blue', alpha=0.7)
            ax.plot(reconstructed_c.cpu().numpy(), label='Reconstructed Channel 1', color='orange', alpha=0.7)

            if mask is not None:
                # Shade every fully-masked temporal patch of the displayed channel.
                # (The original iterated range(patch_H) over a 1-row slice, which
                # would raise IndexError for patch_H > 1.)
                mask_row = mask[batch_idx, 1, :]
                for j in range(T // patch_W):
                    if mask_row[j * patch_W:(j + 1) * patch_W].all():
                        ax.axvspan(j * patch_W, (j + 1) * patch_W, color='lightgray', alpha=0.1)

            ax.set_title(f"Signal Reconstruction - batch_ {batch_idx}")
            ax.legend()

            # Log the figure on TensorBoard with the sample index in the title.
            self.logger.experiment.add_figure(f'{title}_batch {batch_idx}', fig, self.current_epoch)
            plt.close(fig)
#*----------------------------------------------------------------------------*
#* Copyright (C) 2026 ETH Zurich, Switzerland                                 *
#* SPDX-License-Identifier: Apache-2.0                                        *
#*                                                                            *
#* Licensed under the Apache License, Version 2.0 (the "License");            *
#* you may not use this file except in compliance with the License.           *
#* You may obtain a copy of the License at                                    *
#*                                                                            *
#* http://www.apache.org/licenses/LICENSE-2.0                                 *
#*                                                                            *
#* Unless required by applicable law or agreed to in writing, software        *
#* distributed under the License is distributed on an "AS IS" BASIS,          *
#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.   *
#* See the License for the specific language governing permissions and        *
#* limitations under the License.                                             *
#*                                                                            *
#* Author: Marija Zelic                                                       *
#* Author: Thorir Mar Ingolfsson                                              *
#*----------------------------------------------------------------------------*
import torch
import h5py
import numpy as np
from models.modules.lead_positions import (
    get_channel_indices,
    get_channel_locations,
    map_lead_labels_to_angles,
)

class BioSignal_Dataset(torch.utils.data.Dataset):
    """
    Unified dataset class for pretraining data.
    Supports EEG (TUEG, Siena) and ECG/ECG&PPG (Code15, MIMIC-IV, PulseDB) HDF5 files.

    The HDF5 file is scanned once at construction to build a flat
    (group, sample) index and then closed; a per-process handle is opened
    lazily on first __getitem__ so the dataset is safe to use with
    DataLoader num_workers > 0 (an h5py handle must not be shared across
    worker forks).

    Args:
        hdf5_file: Path to the .h5 file.
        channels: List of channel names.
        location_fn: "eeg" or "ecg" — selects the channel-location encoding.
        sensor_type: 0 (ECG), 1 (EEG), 2 (PPG) — scalar, or a per-channel list.
        num_channels: Number of channels taken from the full channel list
            (None keeps all of them).
    """
    def __init__(
        self,
        hdf5_file: str,
        channels: list[str],
        location_fn: str,
        sensor_type: int | list[int],
        num_channels: int | None = None,
    ):
        super().__init__()
        self.hdf5_file = hdf5_file
        channel_names = channels[:num_channels] if num_channels else channels
        self.num_channels = len(channel_names)

        # EEG uses 3-D electrode coordinates; ECG/PPG use lead angles.
        if location_fn == "eeg":
            locs = np.stack(get_channel_locations(channel_names), axis=0)
            self.channel_locations = torch.from_numpy(locs).float()
        else:
            self.channel_locations = torch.FloatTensor(map_lead_labels_to_angles(channel_names))

        self.channel_indices = torch.tensor(get_channel_indices(channel_names), dtype=torch.long)

        # Scalar sensor types are broadcast to one entry per channel.
        if isinstance(sensor_type, list):
            self.sensor_type = torch.tensor(sensor_type, dtype=torch.long)
        else:
            self.sensor_type = torch.full((self.num_channels, ), sensor_type, dtype=torch.long)

        # Build the flat index with a short-lived handle, then close it;
        # the read handle is (re-)opened lazily inside each worker process.
        self.data = None
        self.index_map = []
        with h5py.File(hdf5_file, "r") as f:
            # 'index_map' is a bookkeeping group in some files, not sample data.
            self.keys = [key for key in f.keys() if key != 'index_map']
            for key in self.keys:
                group_size = len(f[key]["X"])
                self.index_map.extend((key, i) for i in range(group_size))

    def __len__(self):
        return len(self.index_map)

    def __getitem__(self, index):
        # Lazy per-worker open: each DataLoader worker gets its own handle.
        if self.data is None:
            self.data = h5py.File(self.hdf5_file, "r")

        group_key, sample_idx = self.index_map[index]
        X = torch.FloatTensor(self.data[group_key]["X"][sample_idx])

        item = {
            "input": X,
            "channel_names": self.channel_indices,
            "channel_locations": self.channel_locations,
            "sensor_type": self.sensor_type
        }

        return item

    def __del__(self):
        # Guard against partially-constructed instances and never-opened handles.
        data = getattr(self, "data", None)
        if data is not None:
            data.close()
def preprocess_signal(waveform, fs, low, high, downsample_fs=None, upsample_fs=None):
    """
    Preprocess an ECG/PPG waveform: sanitize non-finite samples, optionally
    upsample, band-pass filter, remove powerline interference, and optionally
    downsample.

    Args:
        waveform (np.ndarray): Signal of shape (num_channels, num_samples),
            e.g. (12, 5000) for 12-lead ECG.
        fs (int): Sampling frequency of the input signal in Hz.
        low (float): Low cutoff frequency of the band-pass filter in Hz.
        high (float): High cutoff frequency of the band-pass filter in Hz.
        downsample_fs (int): Downsampling frequency. If None, no downsampling is applied.
        upsample_fs (int): Upsampling frequency. If None, no upsampling is applied.

    Returns:
        np.ndarray: Filtered (and possibly resampled) waveform.
    """
    # -------- Missing values -----------------
    # Replace NaN/Inf with zeros so the IIR filters stay stable.
    # (PTB-XL currently has none, but other sources may.)
    check_nan_inf = (~np.isfinite(waveform)).sum()
    if check_nan_inf != 0:
        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)

    # ------- Upsampling -----------------
    # Upsample first so all subsequent filtering runs at the target rate.
    if upsample_fs is not None:
        waveform = resample_poly(waveform, up=upsample_fs, down=fs, axis=-1)
        fs = upsample_fs

    # -------- Bandpass filtering ------------
    # 4th-order Butterworth; cutoffs normalized to the Nyquist frequency.
    b, a = butter(4, [low/(fs/2), high/(fs/2)], btype="band")
    waveform_notch = filtfilt(b, a, waveform, axis=-1)

    # -------- Notch filtering -------------
    # Remove 50 Hz and 60 Hz powerline interference. A notch at or above the
    # Nyquist frequency is skipped — iirnotch rejects w0 >= 1, so the original
    # code crashed for any fs <= 120 Hz.
    q = 30  # notch quality factor
    for notch_freq in (50, 60):
        if notch_freq < fs / 2:
            b_notch, a_notch = iirnotch(notch_freq, q, fs)
            waveform_notch = filtfilt(b_notch, a_notch, waveform_notch, axis=-1)

    # -------- Downsampling ----------------
    # General polyphase resampling for any integer frequency pair: derive the
    # up/down factors from the greatest common divisor (which requires ints).
    if downsample_fs is not None:
        gcd_value = math.gcd(int(fs), int(downsample_fs))
        up_factor = int(downsample_fs) // gcd_value
        down_factor = int(fs) // gcd_value

        resampled = resample_poly(waveform_notch, up=up_factor, down=down_factor, axis=-1)
    else:
        resampled = waveform_notch

    return resampled
def time_segmenting(signal, split_signal, sampling_rate, downsample_fs=None, upsample_fs=None):
    """
    Split a processed signal into consecutive segments of `split_signal` seconds.

    The effective sampling rate is `downsample_fs` or `upsample_fs` when one of
    them was used during preprocessing (at most one of the two should be set),
    otherwise `sampling_rate`. The signal length must be an exact multiple of
    the segment length; no padding is applied.

    Args:
        signal (np.ndarray): Processed signal of shape [num_channels, time_length].
        split_signal (int): Segment length in seconds.
        sampling_rate (int): Original sampling rate of the signal in Hz.
        downsample_fs (int): Downsampling frequency used in preprocessing, or None.
        upsample_fs (int): Upsampling frequency used in preprocessing, or None.

    Returns:
        list[np.ndarray]: Segments, each of shape [num_channels, effective_fs * split_signal].

    Raises:
        ValueError: If the signal length is not divisible by the segment length.
    """
    # Determine the effective sampling rate after preprocessing.
    if downsample_fs is not None:
        effective_fs = downsample_fs
    elif upsample_fs is not None:
        effective_fs = upsample_fs
    else:
        effective_fs = sampling_rate
    split_samples = effective_fs * split_signal

    # Raise instead of assert: assertions are stripped under `python -O`,
    # which would silently allow np.split to fail with a cryptic error.
    if signal.shape[-1] % split_samples != 0:
        raise ValueError(
            f"Signal length ({signal.shape[-1]}) must be exactly divisible by {split_samples}."
        )

    num_sections = signal.shape[-1] // split_samples

    # Splitting along the time axis.
    return np.split(signal, num_sections, axis=-1)
2f23a79a2bd99bb73dec9b24201b156d383c5c90 Mon Sep 17 00:00:00 2001 From: Masa Zelic Date: Tue, 28 Apr 2026 15:02:50 +0200 Subject: [PATCH 3/3] File org --- .../data_module/dataset_types.yaml | 0 ...netune_data_module_multimodal_PanLUNA.yaml | 0 ...finetune_data_module_unimodal_PanLUNA.yaml | 0 .../multiloader_data_module_PanLUNA.yaml | 0 config/experiment/PanLUNA_finetune.yaml | 120 ++++++++++++++++++ config/experiment/PanLUNA_pretrain.yaml | 97 ++++++++++++++ .../model/PanLUNA_finetune.yaml | 0 .../model/PanLUNA_pretrain.yaml | 0 .../task/finetune_task_PanLUNA.yaml | 0 .../task/pretrain_task_PanLUNA.yaml | 0 .../finetuning_multimodal_datasets_PanLUNA.py | 0 .../finetuning_unimodal_datasets_PanLUNA.py | 0 .../pretraining_datasets_PanLUNA.py | 0 PanLUNA.md => docs/model/PanLUNA.md | 0 .../make_code15_dataset.py | 0 .../make_cpsc2018_dataset.py | 0 .../make_csn_dataset.py | 0 .../make_hmc_dataset.py | 0 .../make_mimic_iv_dataset.py | 0 .../make_ptbxl_dataset.py | 0 .../make_pulsedb_dataset.py | 0 .../make_seed_vii_dataset.py | 0 .../make_siena_dataset.py | 0 .../make_wesad_dataset.py | 0 .../process_raw_ecg.py | 0 PanLUNA.py => models/PanLUNA.py | 0 .../modules/lead_positions.py | 0 requirements.txt | 1 + .../finetune_task_PanLUNA.py | 0 .../pretrain_task_PanLUNA.py | 0 30 files changed, 218 insertions(+) rename dataset_types.yaml => config/data_module/dataset_types.yaml (100%) rename finetune_data_module_multimodal_PanLUNA.yaml => config/data_module/finetune_data_module_multimodal_PanLUNA.yaml (100%) rename finetune_data_module_unimodal_PanLUNA.yaml => config/data_module/finetune_data_module_unimodal_PanLUNA.yaml (100%) rename multiloader_data_module_PanLUNA.yaml => config/data_module/multiloader_data_module_PanLUNA.yaml (100%) create mode 100644 config/experiment/PanLUNA_finetune.yaml create mode 100644 config/experiment/PanLUNA_pretrain.yaml rename PanLUNA_finetune.yaml => config/model/PanLUNA_finetune.yaml (100%) rename PanLUNA_pretrain.yaml => 
config/model/PanLUNA_pretrain.yaml (100%) rename finetune_task_PanLUNA.yaml => config/task/finetune_task_PanLUNA.yaml (100%) rename pretrain_task_PanLUNA.yaml => config/task/pretrain_task_PanLUNA.yaml (100%) rename finetuning_multimodal_datasets_PanLUNA.py => datasets/finetuning_multimodal_datasets_PanLUNA.py (100%) rename finetuning_unimodal_datasets_PanLUNA.py => datasets/finetuning_unimodal_datasets_PanLUNA.py (100%) rename pretraining_datasets_PanLUNA.py => datasets/pretraining_datasets_PanLUNA.py (100%) rename PanLUNA.md => docs/model/PanLUNA.md (100%) rename make_code15_dataset.py => make_datasets/make_code15_dataset.py (100%) rename make_cpsc2018_dataset.py => make_datasets/make_cpsc2018_dataset.py (100%) rename make_csn_dataset.py => make_datasets/make_csn_dataset.py (100%) rename make_hmc_dataset.py => make_datasets/make_hmc_dataset.py (100%) rename make_mimic_iv_dataset.py => make_datasets/make_mimic_iv_dataset.py (100%) rename make_ptbxl_dataset.py => make_datasets/make_ptbxl_dataset.py (100%) rename make_pulsedb_dataset.py => make_datasets/make_pulsedb_dataset.py (100%) rename make_seed_vii_dataset.py => make_datasets/make_seed_vii_dataset.py (100%) rename make_siena_dataset.py => make_datasets/make_siena_dataset.py (100%) rename make_wesad_dataset.py => make_datasets/make_wesad_dataset.py (100%) rename process_raw_ecg.py => make_datasets/process_raw_ecg.py (100%) rename PanLUNA.py => models/PanLUNA.py (100%) rename lead_positions.py => models/modules/lead_positions.py (100%) rename finetune_task_PanLUNA.py => tasks/finetune_task_PanLUNA.py (100%) rename pretrain_task_PanLUNA.py => tasks/pretrain_task_PanLUNA.py (100%) diff --git a/dataset_types.yaml b/config/data_module/dataset_types.yaml similarity index 100% rename from dataset_types.yaml rename to config/data_module/dataset_types.yaml diff --git a/finetune_data_module_multimodal_PanLUNA.yaml b/config/data_module/finetune_data_module_multimodal_PanLUNA.yaml similarity index 100% rename from 
finetune_data_module_multimodal_PanLUNA.yaml rename to config/data_module/finetune_data_module_multimodal_PanLUNA.yaml diff --git a/finetune_data_module_unimodal_PanLUNA.yaml b/config/data_module/finetune_data_module_unimodal_PanLUNA.yaml similarity index 100% rename from finetune_data_module_unimodal_PanLUNA.yaml rename to config/data_module/finetune_data_module_unimodal_PanLUNA.yaml diff --git a/multiloader_data_module_PanLUNA.yaml b/config/data_module/multiloader_data_module_PanLUNA.yaml similarity index 100% rename from multiloader_data_module_PanLUNA.yaml rename to config/data_module/multiloader_data_module_PanLUNA.yaml diff --git a/config/experiment/PanLUNA_finetune.yaml b/config/experiment/PanLUNA_finetune.yaml new file mode 100644 index 0000000..99e1347 --- /dev/null +++ b/config/experiment/PanLUNA_finetune.yaml @@ -0,0 +1,120 @@ + +# @package _global_ +tag: PanLUNA_finetune +model_size: tiny +gpus: 4 +num_nodes: 1 +num_workers: 8 +batch_size: 256 + +training: True +final_validate: True +final_test: True +finetune_pretrained: True +resume: False +find_unused_parameters: True + +label_smoothing: 0 +layerwise_lr_decay: 0.75 +scheduler_type: cosine +# Change for different type of task +# bc - binary classification +# mlp - multilabel classification +# mcc - multiclass classification +classification_type: "#CHANGEME" + +pretrained_checkpoint_path: "#CHANGEME" + +callbacks: + progress_bar: + _target_: 'pytorch_lightning.callbacks.TQDMProgressBar' + refresh_rate: 50 + early_stopping: + _target_: 'pytorch_lightning.callbacks.EarlyStopping' + monitor: 'val_loss' + patience: 12 + mode: 'min' + verbose: True + +input_normalization: + normalize: True + +finetuning: + mode: "lora" # options: "full", "freeze_encoder", "lora" + lora: + r: 16 + alpha: 32 + dropout: 0.10 + target_modules: + - 'blocks.0.attn.qkv_proj' + - 'blocks.1.attn.qkv_proj' + - 'blocks.2.attn.qkv_proj' + - 'blocks.3.attn.qkv_proj' + - 'blocks.4.attn.qkv_proj' + - 'blocks.5.attn.qkv_proj' + - 
'blocks.0.attn.proj' + - 'blocks.1.attn.proj' + - 'blocks.2.attn.proj' + - 'blocks.3.attn.proj' + - 'blocks.4.attn.proj' + - 'blocks.5.attn.proj' + - 'cross_attn.cross_attention' + - 'cross_attn.cross_attention.out_proj' + - 'cross_attn.query_self_attn.layers.0.self_attn' + - 'cross_attn.query_self_attn.layers.1.self_attn' + - 'cross_attn.query_self_attn.layers.2.self_attn' + - 'cross_attn.query_self_attn.layers.0.self_attn.out_proj' + - 'cross_attn.query_self_attn.layers.1.self_attn.out_proj' + - 'cross_attn.query_self_attn.layers.2.self_attn.out_proj' + +io: + checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints + version: ${tag}_${model_size}_finetune + base_output_path: "#CHANGEME" + +defaults: + - override /data_module: finetune_data_module_unimodal_PanLUNA # or finetune_data_module_multimodal_PanLUNA + - override /model: PanLUNA_finetune + - override /scheduler: cosine + - override /task: finetune_task_PanLUNA + - override /criterion: finetune_criterion + +model: + patch_size: 32 + embed_dim: 64 + num_heads: 2 + depth: 6 + num_queries: 4 + mlp_ratio: 4 + drop_path: 0.1 + num_classes: "#CHANGEME" + +trainer: + accelerator: gpu + num_nodes: ${num_nodes} + devices: ${gpus} + strategy: ddp + max_epochs: 100 + precision: bf16-mixed + +model_checkpoint: + save_last: True + monitor: "val_loss" + mode: "min" + save_top_k: 1 + every_n_epochs: 1 + +optimizer: + optim: 'AdamW' + lr: 5.0e-4 + betas: [0.9, 0.999] + weight_decay: 0.05 + +scheduler: + trainer: ${trainer} + min_lr: 5.0e-6 # minimum LR for the cosine scheduler + warmup_lr_init: 2.5e-7 # initial LR for the warmup phase + warmup_epochs: 0 + + + diff --git a/config/experiment/PanLUNA_pretrain.yaml b/config/experiment/PanLUNA_pretrain.yaml new file mode 100644 index 0000000..2042d20 --- /dev/null +++ b/config/experiment/PanLUNA_pretrain.yaml @@ -0,0 +1,97 @@ +#*----------------------------------------------------------------------------* +#* Copyright (C) 2026 ETH Zurich, Switzerland * +#* 
SPDX-License-Identifier: Apache-2.0 * +#* * +#* Licensed under the Apache License, Version 2.0 (the "License"); * +#* you may not use this file except in compliance with the License. * +#* You may obtain a copy of the License at * +#* * +#* http://www.apache.org/licenses/LICENSE-2.0 * +#* * +#* Unless required by applicable law or agreed to in writing, software * +#* distributed under the License is distributed on an "AS IS" BASIS, * +#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * +#* See the License for the specific language governing permissions and * +#* limitations under the License. * +#* * +#* Author: Marija Zelic * +#* Author: Thorir Mar Ingolfsson * +#*----------------------------------------------------------------------------* +# @package _global_ +tag: PanLUNA_pretrain +model_size: tiny +gpus: 4 +num_nodes: 1 +num_workers: 8 +batch_size: 1024 + +final_validate: True +final_test: False +resume: False +pretrained_checkpoint_path: null +freeze_encoder: False +io: + checkpoint_dirpath: "#CHANGEME" + checkpoint_filepath: null + +defaults: + - override /data_module: multiloader_data_module_PanLUNA + - override /model: PanLUNA_pretrain + - override /scheduler: cosine + - override /task: pretrain_task_PanLUNA + - override /criterion: pretrain_criterion + +query_specialization_criterion: + loss_type: 'l2' + loss_coeff: 0.8 + +finetuning: + mode: null + +model: + patch_size: 32 + embed_dim: 64 + num_heads: 2 + depth: 6 + num_queries: 4 + mlp_ratio: 4 + drop_path: 0.0 + num_classes: 0 + +masking: + patch_size: [1, 32] + masking_ratio: 0.5 + unmasked_loss_coeff: 0.05 + +input_normalization: + normalize: True + +scheduler: + trainer: ${trainer} + min_lr: 2.5e-7 # minimum LR for the cosine scheduler + warmup_lr_init: 2.5e-7 # initial LR for the warmup phase + warmup_epochs: 10 + +trainer: + accelerator: gpu + num_nodes: ${num_nodes} + devices: ${gpus} + max_epochs: 150 + strategy: ddp + accumulate_grad_batches: 1 + 
check_val_every_n_epoch: 1 + use_distributed_sampler: False + precision: bf16-mixed + gradient_clip_val: 1.0 + +model_checkpoint: + save_last: True + monitor: "val_loss" + mode: "min" + save_top_k: 1 + +optimizer: + optim: 'AdamW' + lr: 1.25e-4 + betas: [0.9, 0.98] + weight_decay: 0.05 \ No newline at end of file diff --git a/PanLUNA_finetune.yaml b/config/model/PanLUNA_finetune.yaml similarity index 100% rename from PanLUNA_finetune.yaml rename to config/model/PanLUNA_finetune.yaml diff --git a/PanLUNA_pretrain.yaml b/config/model/PanLUNA_pretrain.yaml similarity index 100% rename from PanLUNA_pretrain.yaml rename to config/model/PanLUNA_pretrain.yaml diff --git a/finetune_task_PanLUNA.yaml b/config/task/finetune_task_PanLUNA.yaml similarity index 100% rename from finetune_task_PanLUNA.yaml rename to config/task/finetune_task_PanLUNA.yaml diff --git a/pretrain_task_PanLUNA.yaml b/config/task/pretrain_task_PanLUNA.yaml similarity index 100% rename from pretrain_task_PanLUNA.yaml rename to config/task/pretrain_task_PanLUNA.yaml diff --git a/finetuning_multimodal_datasets_PanLUNA.py b/datasets/finetuning_multimodal_datasets_PanLUNA.py similarity index 100% rename from finetuning_multimodal_datasets_PanLUNA.py rename to datasets/finetuning_multimodal_datasets_PanLUNA.py diff --git a/finetuning_unimodal_datasets_PanLUNA.py b/datasets/finetuning_unimodal_datasets_PanLUNA.py similarity index 100% rename from finetuning_unimodal_datasets_PanLUNA.py rename to datasets/finetuning_unimodal_datasets_PanLUNA.py diff --git a/pretraining_datasets_PanLUNA.py b/datasets/pretraining_datasets_PanLUNA.py similarity index 100% rename from pretraining_datasets_PanLUNA.py rename to datasets/pretraining_datasets_PanLUNA.py diff --git a/PanLUNA.md b/docs/model/PanLUNA.md similarity index 100% rename from PanLUNA.md rename to docs/model/PanLUNA.md diff --git a/make_code15_dataset.py b/make_datasets/make_code15_dataset.py similarity index 100% rename from make_code15_dataset.py rename to 
make_datasets/make_code15_dataset.py diff --git a/make_cpsc2018_dataset.py b/make_datasets/make_cpsc2018_dataset.py similarity index 100% rename from make_cpsc2018_dataset.py rename to make_datasets/make_cpsc2018_dataset.py diff --git a/make_csn_dataset.py b/make_datasets/make_csn_dataset.py similarity index 100% rename from make_csn_dataset.py rename to make_datasets/make_csn_dataset.py diff --git a/make_hmc_dataset.py b/make_datasets/make_hmc_dataset.py similarity index 100% rename from make_hmc_dataset.py rename to make_datasets/make_hmc_dataset.py diff --git a/make_mimic_iv_dataset.py b/make_datasets/make_mimic_iv_dataset.py similarity index 100% rename from make_mimic_iv_dataset.py rename to make_datasets/make_mimic_iv_dataset.py diff --git a/make_ptbxl_dataset.py b/make_datasets/make_ptbxl_dataset.py similarity index 100% rename from make_ptbxl_dataset.py rename to make_datasets/make_ptbxl_dataset.py diff --git a/make_pulsedb_dataset.py b/make_datasets/make_pulsedb_dataset.py similarity index 100% rename from make_pulsedb_dataset.py rename to make_datasets/make_pulsedb_dataset.py diff --git a/make_seed_vii_dataset.py b/make_datasets/make_seed_vii_dataset.py similarity index 100% rename from make_seed_vii_dataset.py rename to make_datasets/make_seed_vii_dataset.py diff --git a/make_siena_dataset.py b/make_datasets/make_siena_dataset.py similarity index 100% rename from make_siena_dataset.py rename to make_datasets/make_siena_dataset.py diff --git a/make_wesad_dataset.py b/make_datasets/make_wesad_dataset.py similarity index 100% rename from make_wesad_dataset.py rename to make_datasets/make_wesad_dataset.py diff --git a/process_raw_ecg.py b/make_datasets/process_raw_ecg.py similarity index 100% rename from process_raw_ecg.py rename to make_datasets/process_raw_ecg.py diff --git a/PanLUNA.py b/models/PanLUNA.py similarity index 100% rename from PanLUNA.py rename to models/PanLUNA.py diff --git a/lead_positions.py b/models/modules/lead_positions.py similarity 
index 100% rename from lead_positions.py rename to models/modules/lead_positions.py diff --git a/requirements.txt b/requirements.txt index 8804a2b..ab79d32 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,6 +10,7 @@ tqdm pandas mne nvidia-nccl-cu11 +peft psutil pyyaml tensorboardX diff --git a/finetune_task_PanLUNA.py b/tasks/finetune_task_PanLUNA.py similarity index 100% rename from finetune_task_PanLUNA.py rename to tasks/finetune_task_PanLUNA.py diff --git a/pretrain_task_PanLUNA.py b/tasks/pretrain_task_PanLUNA.py similarity index 100% rename from pretrain_task_PanLUNA.py rename to tasks/pretrain_task_PanLUNA.py