From 0ca0479a6e82490b16f61425639ffea946b5b80d Mon Sep 17 00:00:00 2001
From: Marija Zelic <145926170+masazelic@users.noreply.github.com>
Date: Mon, 27 Apr 2026 09:43:59 +0200
Subject: [PATCH 1/3] Add fixed PanLUNA to repos
cleared PanLUNA
---
PanLUNA.md | 148 ++++
PanLUNA.py | 311 ++++++++
PanLUNA_finetune.yaml | 30 +
PanLUNA_pretrain.yaml | 30 +
README.md | 49 +-
dataset_types.yaml | 33 +
finetune_data_module_multimodal_PanLUNA.yaml | 109 +++
finetune_data_module_unimodal_PanLUNA.yaml | 49 ++
finetune_task_PanLUNA.py | 419 +++++++++++
finetune_task_PanLUNA.yaml | 22 +
finetuning_multimodal_datasets_PanLUNA.py | 126 ++++
finetuning_unimodal_datasets_PanLUNA.py | 97 +++
lead_positions.py | 212 ++++++
make_code15_dataset.py | 146 ++++
make_cpsc2018_dataset.py | 147 ++++
make_csn_dataset.py | 157 ++++
make_hmc_dataset.py | 187 +++++
make_mimic_iv_dataset.py | 127 ++++
make_ptbxl_dataset.py | 158 ++++
make_pulsedb_dataset.py | 152 ++++
make_seed_vii_dataset.py | 202 ++++++
make_siena_dataset.py | 124 ++++
make_wesad_dataset.py | 274 +++++++
multiloader_data_module_PanLUNA.yaml | 720 +++++++++++++++++++
pretrain_task_PanLUNA.py | 273 +++++++
pretrain_task_PanLUNA.yaml | 22 +
pretraining_datasets_PanLUNA.py | 96 +++
process_raw_ecg.py | 121 ++++
28 files changed, 4540 insertions(+), 1 deletion(-)
create mode 100644 PanLUNA.md
create mode 100644 PanLUNA.py
create mode 100644 PanLUNA_finetune.yaml
create mode 100644 PanLUNA_pretrain.yaml
create mode 100644 dataset_types.yaml
create mode 100644 finetune_data_module_multimodal_PanLUNA.yaml
create mode 100644 finetune_data_module_unimodal_PanLUNA.yaml
create mode 100644 finetune_task_PanLUNA.py
create mode 100644 finetune_task_PanLUNA.yaml
create mode 100644 finetuning_multimodal_datasets_PanLUNA.py
create mode 100644 finetuning_unimodal_datasets_PanLUNA.py
create mode 100644 lead_positions.py
create mode 100644 make_code15_dataset.py
create mode 100644 make_cpsc2018_dataset.py
create mode 100644 make_csn_dataset.py
create mode 100644 make_hmc_dataset.py
create mode 100644 make_mimic_iv_dataset.py
create mode 100644 make_ptbxl_dataset.py
create mode 100644 make_pulsedb_dataset.py
create mode 100644 make_seed_vii_dataset.py
create mode 100644 make_siena_dataset.py
create mode 100644 make_wesad_dataset.py
create mode 100644 multiloader_data_module_PanLUNA.yaml
create mode 100644 pretrain_task_PanLUNA.py
create mode 100644 pretrain_task_PanLUNA.yaml
create mode 100644 pretraining_datasets_PanLUNA.py
create mode 100644 process_raw_ecg.py
diff --git a/PanLUNA.md b/PanLUNA.md
new file mode 100644
index 0000000..312f07d
--- /dev/null
+++ b/PanLUNA.md
@@ -0,0 +1,148 @@
+## PanLUNA
+
+PanLUNA is a self-supervised **pan-modal** biosignal foundation model that jointly processes **EEG, ECG, and PPG** within a single shared encoder. Extending LUNA's channel-unification module, PanLUNA treats multimodal channels as entries in a **unified query set augmented with sensor-type embeddings**, enabling efficient cross-modal early fusion while remaining inherently **robust to missing modalities** at inference time. Despite its compact 5.4M-parameter footprint, PanLUNA matches or exceeds models up to 57× larger, and supports quantization-aware INT8 deployment on the GAP9 ultra-low-power RISC-V microcontroller for continuous wearable monitoring.
+
+---
+
+### Default Input Assumptions
+
+All modalities are resampled to 256 Hz and segmented into non-overlapping 5-second windows with a patch size of 32 samples, unless a downstream task specifies otherwise (e.g., 10-second windows for ECG benchmarks, 30-second epochs for sleep staging).
+
+| Modality | Channels | Native Sampling Rate |
+|----------|----------|----------------------|
+| EEG | 20–22 (pre-training); 29 (Siena) | 250–512 Hz |
+| ECG | 12 (pre-training and cardiac benchmarks); 1 (HMC sleep staging) | 400–500 Hz |
+| PPG | 1 (PulseDB) | 125 Hz |
+
+Missing modalities are handled natively at inference without any architectural modification.
+
+---
+
+### Preprocessing
+
+A standardized modality-specific preprocessing pipeline is applied to all data:
+
+1. **Filtering**: Bandpass filtering with a 4th-order Butterworth filter, with modality-specific cutoffs: EEG 0.1–75 Hz; ECG 0.5–120 Hz; PPG 0.5–8 Hz. A notch filter (50 Hz or 60 Hz) is additionally applied.
+2. **Resampling**: All signals resampled to 256 Hz.
+3. **Normalization**: Per-channel z-score normalization, to account for the large amplitude differences across modalities (e.g., EEG in µV vs. ECG in mV).
+4. **Segmentation**: Non-overlapping 5-second windows during pre-training; task-specific windowing during fine-tuning (e.g., 10-second windows for PTB-XL/CSN, 30-second epochs for HMC sleep staging).
+
+---
+
+### Architecture Overview
+
+PanLUNA extends LUNA to the multimodal setting by generalizing topology invariance to cross-modal fusion.
+
+1. **Input Representation**
+ Channels from all modalities are concatenated along the channel dimension before entering the model. **Sensor-type embeddings** are introduced via a modality-specific lookup table, added to channel features at the input stage to distinguish sensing modalities. Channel positional encodings are modality-specific:
+ - **EEG**: Normalized 3D electrode coordinates encoded with sinusoidal embeddings (as in LUNA).
+ - **ECG**: Lead-angle estimates derived from [anatomical measurements on 30 body scans](https://www.ijcai.org/proceedings/2021/0495.pdf), constructing a spatial encoding analogous to EEG electrode positioning.
+ - **PPG**: Neutral coordinate (0, 0) assigned; the model relies on the sensor-type embedding for modality identification.
+
+2. **Patch Feature Extraction**
+ Signals are partitioned into short temporal patches and embedded via lightweight convolutional encoders combined with frequency features from the real-valued FFT. Patch-level features are augmented with positional encodings and sensor-type embeddings before entering the unification module.
+
+3. **Channel–Modality Unification Module**
+ Cross-attention aggregates information across both channels and modalities through a shared set of latent queries. This design removes the requirement for paired multimodal recordings during pre-training and enables training on large-scale unimodal corpora.
+
+4. **Temporal Transformer Encoder**
+ The unified latent sequence is processed by a patch-wise temporal Transformer with **Rotary Positional Embeddings (RoPE)** to capture long-range temporal dependencies. Self-attention operates on the fixed-size latent representation, fully decoupled from electrode count and modality composition.
+
+5. **Decoding and Classifier Heads**
+ During pre-training, a reconstruction decoder attends to encoder outputs to recover masked signal patches in a channel-specific manner. During fine-tuning this decoder is discarded and replaced by a lightweight aggregation query that pools the encoder output into a single representation, fed to a classification head. Three adaptation strategies are supported:
+ - **Full Fine-tuning (FF)**: All 5.4M parameters updated.
+ - **Frozen Encoder (FE)**: Backbone fixed; only the classification head (~400k parameters) trained.
+ - **LoRA**: Low-rank matrices (rank 16, ~180k parameters, ~580k total) injected into selected Transformer layers.
+
+---
+
+### Self-Supervised Learning (SSL) Objective
+
+PanLUNA is pre-trained with a **masked signal reconstruction** objective. A random subset of patch tokens is masked, and the reconstruction decoder is trained to recover the original signal patches in a channel-specific manner.
+
+---
+
+### Classification Protocols
+
+- **BC – Binary Classification**: Window-level binary label (e.g., normal vs. abnormal EEG on TUAB).
+- **MCC – Multi-class Classification**: Single-label classification per window (e.g., 5-stage sleep scoring on HMC).
+- **Multi-label Classification**: Multiple co-occurring labels per window (e.g., 19-label PTB-XL-Form ECG morphology).
+
+---
+
+### Model Variants
+
+| Variant | Parameters |
+|---------|------------|
+| PanLUNA | 5.4M |
+
+---
+
+### Training Setup
+
+- **Pre-training**
+ - **Datasets**: ~40,000 hours of heterogeneous biosignal data across five corpora:
+
+ | Dataset | Modality | Subjects | Channels | FS (Hz) | Window |
+ |---------|----------|----------|----------|---------|--------|
+ | TUEG | EEG | 14,987 | 20/22 | 250 | 5 s |
+ | Siena | EEG | 14 | 29 | 512 | 5 s |
+ | MIMIC-IV | ECG | 161,352 | 12 | 500 | 5 s |
+ | CODE-15% | ECG | 233,700 | 12 | 400 | 5 s |
+ | PulseDB | ECG, PPG | 5,361 | 2 | 125 | 5 s |
+
+ - **Objective**: Masked signal reconstruction; each modality can be used independently (no paired multimodal data required).
+
+- **Fine-tuning**
+ - Reconstruction decoder replaced with aggregation query + classification head; three adaptation strategies available (FF, FE, LoRA).
+ - **Loss**: Cross-Entropy for multi-class; BCE for multi-label classification.
+ - **Dataset splits**:
+ - **TUAB**: Official predefined train/val/test split.
+ - **PTB-XL (Super/Sub/Form/Rhythm)** and **CSN**: MERL ICML 2024 protocol.
+ - **HMC (sleep staging)**: Splits as in PhysioOmni.
+
+- **Quantization**
+ - Post-Training Quantization (PTQ) and Quantization-Aware Training (QAT) via Brevitas; evaluated at INT8, INT4, and INT2 weights. QAT runs for 15 fine-tuning epochs and recovers ≥96% of FP32 performance at INT8; INT2 weights achieve up to 16× storage reduction with graceful degradation.
+
+---
+
+### Edge Deployment (GAP9)
+
+PanLUNA is deployed on the **GAP9 ultra-low-power RISC-V microcontroller** (9-core cluster at 370 MHz, 1.5 MB L2 SRAM) using the BioFoundation edge framework with automated operator tiling, double-buffered DMA, NE16 acceleration, and custom tiled kernels for cross-attention projections and sensor-type embedding lookup.
+
+| Configuration | Channels | Window | MACs | Latency | Energy | Power |
+|---------------|----------|--------|------|---------|--------|-------|
+| ECG only | 12 | 10 s | 120.5 M | 325.6 ms | 18.8 mJ | 60.2 mW |
+| EEG + ECG | 5 | 30 s | 446.2 M | 1.206 s | 68.65 mJ | 56.9 mW |
+
+Streaming latency for ECG (patch-triggered): **450.6 ms** (125 ms acquisition + 325.6 ms compute). Estimated continuous monitoring battery life on a 300 mAh / 3.7 V wearable: **~24 days** (ECG-only), **~20 days** (multimodal sleep staging). This is, to our knowledge, the first deployment of a multimodal physiological FM on an ultra-low-power MCU.
+
+---
+
+### Results Summary
+
+**TUAB (Abnormal EEG Detection)**
+- PanLUNA (FF): **81.21%** balanced accuracy, 0.8999 AUC-PR, 0.8932 AUROC — outperforming LUNA-Base and LUNA-Large despite being 8–57× smaller.
+
+**HMC (Multimodal Sleep Staging, 5-class)**
+
+| Variant | Test Modality | Bal. Acc. | Cohen's κ | Weighted F1 |
+|---------|--------------|-----------|-----------|-------------|
+| PanLUNA (FF) | EEG | **0.7416** | **0.6946** | **0.7659** |
+| PanLUNA (FF) | EEG + ECG | 0.7002 | 0.6561 | 0.7383 |
+| PanLUNA (FF) | ECG only | 0.2977 | 0.1095 | 0.2876 |
+| PanLUNA (QAT INT8) | EEG + ECG | 0.7347 | 0.6913 | 0.7273 |
+
+State-of-the-art on HMC; surpasses PhysioOmni by +1.27% balanced accuracy.
+
+**PTB-XL / CSN (Cardiac Benchmarks, LoRA FE, FP32)**
+
+| Task | AUROC |
+|------|-------|
+| PTB-XL Super | 0.9083 |
+| PTB-XL Sub | 0.8880 |
+| PTB-XL Form | 0.8331 |
+| PTB-XL Rhythm | 0.9641 |
+| CSN | 0.9505 |
+
+State-of-the-art on PTB-XL Super and CSN. QAT INT8 recovers ≥96% of FP32 AUROC across all tasks; INT2 weights achieve up to 16× storage reduction with graceful degradation.
\ No newline at end of file
diff --git a/PanLUNA.py b/PanLUNA.py
new file mode 100644
index 0000000..2a16e3d
--- /dev/null
+++ b/PanLUNA.py
@@ -0,0 +1,311 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+
+import math
+from functools import partial
+from typing import List
+import numpy as np
+import random
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+import torch.nn.functional as F
+
+from timm.models.layers import trunc_normal_ as __call_trunc_normal_
+from timm.models.layers import Mlp
+
+from einops import rearrange
+from models.modules.rope_transformer_encoder_block import RotaryTransformerBlock
+from models.modules.frequency_embedder import FrequencyFeatureEmbedder
+from models.modules.lead_positions import ChannelEmbeddings, SensorEmbeddings
+
+def trunc_normal_(tensor, mean=0., std=1.):
+ __call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
+
+def nerf_positional_encoding(coords: torch.Tensor, embed_size: int) -> torch.Tensor:
+ """
+ coords: (N, C, 3)
+ Returns: (N, C, embed_size)
+ """
+ N, C, dim = coords.shape
+ device = coords.device
+ freqs = embed_size // (2 * dim)
+ leftover = embed_size - freqs * 2 * dim
+ freq_bands = 2.0 ** torch.arange(freqs, device=device).float()
+ scaled_coords = coords.unsqueeze(-1) * freq_bands.view(1, 1, 1, -1) # (N, C, dim, freqs)
+ sin_enc = torch.sin(scaled_coords) # (N, C, dim, freqs)
+ cos_enc = torch.cos(scaled_coords) # (N, C, dim, freqs)
+ encoded = torch.stack([sin_enc, cos_enc], dim=-1).permute(0, 1, 3, 2, 4).reshape(N, C, freqs * dim * 2)
+ if leftover > 0:
+ pad = torch.zeros(N, C, leftover, device=device, dtype=coords.dtype)
+ encoded = torch.cat([encoded, pad], dim=-1)
+ return encoded
+
+class PatchReconstructionHeadWithQueries(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 8,
+ embed_dim: int = 768,
+ num_heads: int = 8,
+ num_queries: int = 4,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.embed_dim = embed_dim
+ self.reconstruction_shape = self.input_dim
+ self.num_queries = num_queries
+ # Projection from embed space to pixel space, according to type of input
+ self.decoder_pred = nn.TransformerDecoder(
+ nn.TransformerDecoderLayer(embed_dim, num_heads, dropout=0.0, batch_first=True, activation='gelu', dim_feedforward=int(embed_dim*4), norm_first=True),
+ num_layers=1
+ )
+ self.norm = nn.LayerNorm(embed_dim)
+ self.decoder_linear = Mlp(embed_dim, int(embed_dim*4), input_dim, act_layer=nn.GELU, drop=0.0) #nn.Linear(embed_dim, input_dim, bias=True)
+
+ def forward(self, enc, decoder_queries):
+ """
+ enc: [B, num_patches, embed_dim], embed_dim = Q*D
+ decoder_queries: [B*num_patches, num_channels, embed_dim]
+ """
+
+ B, num_patches, embed_dim = enc.shape
+ enc = rearrange(enc, 'B t (Q D) -> (B t) Q D', Q=self.num_queries)
+ out = self.decoder_pred(decoder_queries, enc) # (B*t, C, D)
+ out = self.norm(out)
+ out = self.decoder_linear(out) # (B*t, C, patch_size)
+ out = rearrange(out, '(B t) C P -> B C (t P)', B=B)
+ return out
+
+class ClassificationHeadWithQueries(nn.Module):
+ def __init__(
+ self,
+ input_dim: int = 8,
+ embed_dim: int = 768,
+ num_queries: int = 8,
+ num_heads: int = 8,
+ num_classes: int = 2,
+ ):
+ super().__init__()
+ self.input_dim = input_dim
+ self.embed_dim = int(embed_dim*num_queries)
+ self.reconstruction_shape = self.input_dim
+ self.decoder_attn = nn.MultiheadAttention(self.embed_dim, num_heads, batch_first=True, dropout=0.15)
+ self.decoder_ffn = Mlp(in_features=self.embed_dim, hidden_features=int(2*self.embed_dim), out_features=num_classes, act_layer=nn.GELU, drop=0.15)
+
+ self.learned_agg = nn.Parameter(torch.randn(1, 1, self.embed_dim), requires_grad=True)
+
+ def forward(self, x):
+ """
+ Output shape:
+ [B, num_tokens, in_chans, input_dim]
+ Args:
+ x: [B, num_tokens+1, embed_dim]
+ channel_embeddings: [B, in_chans, embed_dim]
+ """
+ B, num_patches, embed_dim = x.shape
+ decoder_queries = self.learned_agg.repeat(x.shape[0], 1, 1)
+
+ x = self.decoder_attn(query=decoder_queries, key=x, value=x)[0]
+ x = x[:,0,:]
+ x = self.decoder_ffn(x)
+ return x
+
+class CrossAttentionBlock(nn.Module):
+ def __init__(self, num_queries, input_embed_dim, output_embed_dim, num_heads, dropout_p=0.1, ff_dim=2048, pre_norm=True):
+ super(CrossAttentionBlock, self).__init__()
+ self.num_queries = num_queries
+ self.dropout_p = dropout_p
+ self.query_embed = nn.Parameter(torch.randn(1, num_queries, input_embed_dim), requires_grad=True) # Learnable queries
+ self.cross_attention = nn.MultiheadAttention(embed_dim=input_embed_dim, num_heads=num_heads, dropout=dropout_p,batch_first=True)
+ self.temparature = nn.Parameter(torch.tensor(1.0), requires_grad=False)
+ self.ffn = Mlp(input_embed_dim, ff_dim, output_embed_dim, act_layer=nn.GELU, drop=dropout_p, norm_layer=nn.LayerNorm)
+ self.keys_norm = nn.LayerNorm(input_embed_dim)
+ self.values_norm = nn.LayerNorm(input_embed_dim)
+ self.queries_norm = nn.LayerNorm(input_embed_dim)
+ self.query_self_attn = nn.TransformerEncoder(nn.TransformerEncoderLayer(input_embed_dim, nhead=num_heads, activation='gelu', dim_feedforward=ff_dim, batch_first=True, norm_first=True), num_layers=3)
+
+ def initialize_weights(self):
+ torch.nn.init.orthogonal_(self.query_embed, gain=1.0)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ # we use xavier_uniform following official JAX ViT:
+ torch.nn.init.xavier_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def forward(self, x):
+ # x is the input with shape (batch_size*num_patches, num_channels, embed_dim)
+ batch_size, num_channels, _ = x.size()
+ queries = self.query_embed.repeat(batch_size,1,1)
+ queries = self.queries_norm(queries)
+ keys = self.keys_norm(x)
+ values = self.values_norm(x)
+
+ attention_out, attention_scores = self.cross_attention(query=queries,key=keys,value=values) # Shape: (batch_size*num_patches, num_queries, embed_dim)
+ attention_out = self.ffn(attention_out) + attention_out
+ attention_out = self.query_self_attn(attention_out)
+ return attention_out, attention_scores # Shape: (batch_size*num_patches, num_queries, embed_dim)
+
+class PatchEmbedNetwork(nn.Module):
+ def __init__(self, embed_dim=64, patch_size=40):
+ super(PatchEmbedNetwork, self).__init__()
+ self.patch_size = patch_size
+ self.embed_dim = embed_dim
+ self.in_channels = 1
+ self.out_channels = int(embed_dim//4)
+ self.groups = 4
+ self.kernel_size = int(patch_size//2)
+ self.proj_in = nn.Sequential(
+ nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels, kernel_size=(1, self.kernel_size-1), stride=(1, self.kernel_size//2), padding=(0, self.kernel_size//2-1)),
+ nn.GroupNorm(self.groups, self.out_channels),
+ nn.GELU(),
+
+ nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
+ nn.GroupNorm(self.groups, self.out_channels),
+ nn.GELU(),
+
+ nn.Conv2d(in_channels=self.out_channels, out_channels=self.out_channels, kernel_size=(1, 3), stride=(1, 1), padding=(0, 1)),
+ nn.GroupNorm(self.groups, self.out_channels),
+ nn.GELU(),
+ )
+ def forward(self, x):
+ """
+ x: (B, C, T)
+ output: (B, C*S, D) where S = T//patch_size, D = embed_dim
+ """
+ x = rearrange(x, 'B C (S P) -> B (C S) P', P=self.patch_size)
+ x = x.unsqueeze(1)
+ x = self.proj_in(x)
+ x = rearrange(x, 'B E CS D -> B CS (D E)')
+ return x
+
+class PanLUNA(nn.Module):
+ """
+ LUNA extension for multimodal processing.
+ """
+ def __init__(self, patch_size=40, num_queries=4,
+ embed_dim=64, depth=8, num_heads=2,
+ mlp_ratio=4., norm_layer=nn.LayerNorm,
+ drop_path=0.0, num_classes=0):
+ super().__init__()
+ self.embed_dim = embed_dim
+ self.num_queries = num_queries
+ self.patch_size = patch_size
+ self.patch_embed_size = embed_dim
+ self.num_heads = num_heads
+ self.num_classes = num_classes
+ self.depth = depth
+ self.patch_embed = PatchEmbedNetwork(embed_dim=self.embed_dim, patch_size=patch_size)
+ self.freq_embed = FrequencyFeatureEmbedder(embed_dim=self.embed_dim, patch_size=patch_size)
+ self.sensor_embed = SensorEmbeddings(embed_dim=self.embed_dim)
+ self.channel_location_embedder = nn.Sequential(
+ Mlp(in_features=int(self.patch_embed_size), out_features=int(self.patch_embed_size), hidden_features=int(self.patch_embed_size*2), act_layer=nn.GELU, drop=0.0, norm_layer=nn.LayerNorm),
+ )
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ self.cross_attn = CrossAttentionBlock(num_queries=num_queries, input_embed_dim=self.embed_dim, output_embed_dim=self.embed_dim, num_heads=self.num_heads, ff_dim=int(mlp_ratio*self.embed_dim), pre_norm=True)
+ self.blocks = nn.ModuleList([
+ RotaryTransformerBlock(dim=int(self.embed_dim*self.num_queries), num_heads=int(self.num_heads*self.num_queries), mlp_ratio=mlp_ratio, qkv_bias=True, drop=0.0, attn_drop=0.0, drop_path=drop_path, norm_layer=norm_layer)
+ for i in range(depth)])
+ self.norm = norm_layer(int(self.embed_dim*self.num_queries))
+ if num_classes==0: # reconstruction (pre-training)
+ self.decoder_head = PatchReconstructionHeadWithQueries(input_dim=patch_size, embed_dim=self.embed_dim, num_heads=self.num_heads, num_queries=num_queries)
+ self.channel_emb = ChannelEmbeddings(self.embed_dim)
+ else: # classification
+ self.classifier = ClassificationHeadWithQueries(input_dim=patch_size, num_queries=num_queries, embed_dim=self.embed_dim, num_classes=num_classes, num_heads=self.num_heads)
+ self.mask_token.requires_grad = False # no use of mask token for classification
+ self.initialize_weights()
+
+ def initialize_weights(self):
+ self.cross_attn.initialize_weights()
+ trunc_normal_(self.mask_token, std=.02)
+ self.apply(self._init_weights)
+ self.fix_init_weight()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ torch.nn.init.xavier_normal_(m.weight)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ def fix_init_weight(self):
+ def rescale(param, layer_id):
+ param.div_(math.sqrt(2.0 * layer_id))
+
+ for layer_id, layer in enumerate(self.blocks):
+ rescale(layer.attn.proj.weight.data, layer_id + 1)
+ rescale(layer.mlp.fc2.weight.data, layer_id + 1)
+
+ def prepare_tokens(self, x_signal, channel_locations, sensor_type, mask=None):
+ num_channels = channel_locations.shape[1]
+ num_patches_per_channel = x_signal.shape[-1] // self.patch_size
+ x_patched = self.patch_embed(x_signal)
+ freq_embed = self.freq_embed(x_signal)
+ x_patched = x_patched + freq_embed
+ x_masked = x_patched.clone() # (B, N, D), N = C * num_patches_per_channel
+ if mask is not None:
+ mask_tokens = self.mask_token.repeat(x_masked.shape[0], x_masked.shape[1], 1) # (B, N, D) N = C * num_patches_per_channel
+ mask = rearrange(mask, 'B C (S P) -> B (C S) P', P=self.patch_size) # (B, C, T) -> (B, N, P)
+ mask = (mask.sum(dim=-1) > 0).unsqueeze(-1).float() # (B, N, 1), since a patch is either fully masked or not
+ x_masked = torch.where(mask.bool(), mask_tokens, x_masked)
+ channel_min = torch.min(channel_locations, dim=1, keepdim=True)[0]
+ channel_max = torch.max(channel_locations, dim=1, keepdim=True)[0]
+ channel_locations = (channel_locations - channel_min) / (channel_max - channel_min + 1e-8)
+ if mask is not None:
+ channel_locations = channel_locations + torch.randn_like(channel_locations) * 0.02
+ channel_locations = nerf_positional_encoding(channel_locations, self.patch_embed_size)
+ channel_locations_emb = self.channel_location_embedder(channel_locations)
+
+ # Added part - sensor type
+ sensor_type = self.sensor_embed(sensor_type).repeat_interleave(num_patches_per_channel, dim=0)
+ x_tokenized = rearrange(x_masked, 'B (C t) D -> (B t) C D', C=num_channels)
+ channel_locations_emb = channel_locations_emb.repeat_interleave(num_patches_per_channel, dim=0)
+ x_tokenized = x_tokenized + channel_locations_emb + sensor_type
+
+ return x_tokenized, channel_locations_emb, sensor_type
+
+ def forward(self, x_signal, mask, channel_locations, sensor_type, channel_names=None, **kwargs):
+ x_original = x_signal
+ B, C, T = x_signal.shape
+ x, channel_locations_emb, sensor_type = self.prepare_tokens(x_signal, channel_locations, sensor_type, mask=mask)
+
+ x, attention_scores = self.cross_attn(x) # (B*num_patches, Q, D)
+ x = rearrange(x, '(B t) Q D -> B t (Q D)', B=B) # (B, num_patches, Q*D), Q*D is the new embed_dim
+ num_patches = x.shape[1]
+ for blk in self.blocks:
+ x = blk(x) # (B, N, D)
+ x_latent = self.norm(x) # (B, N, D)
+
+ if self.num_classes > 0:
+ x_classified = self.classifier(x_latent)
+ return x_classified, x_original
+ else:
+ channel_emb = self.channel_emb(channel_names)
+ channel_emb = channel_emb.repeat(num_patches, 1, 1)
+ decoder_queries = channel_locations_emb + channel_emb + sensor_type # (B*N, C, D)
+ x_reconstructed = self.decoder_head(x_latent, decoder_queries)
+ return x_reconstructed, x_original, attention_scores
diff --git a/PanLUNA_finetune.yaml b/PanLUNA_finetune.yaml
new file mode 100644
index 0000000..bd3ea35
--- /dev/null
+++ b/PanLUNA_finetune.yaml
@@ -0,0 +1,30 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+model:
+ _target_: models.PanLUNA.PanLUNA
+ patch_size: 32
+ embed_dim: 64
+ num_heads: 2
+ depth: 6
+ num_queries: 4
+ mlp_ratio: 4
+ drop_path: 0.1
+ num_classes: 2
\ No newline at end of file
diff --git a/PanLUNA_pretrain.yaml b/PanLUNA_pretrain.yaml
new file mode 100644
index 0000000..94d7311
--- /dev/null
+++ b/PanLUNA_pretrain.yaml
@@ -0,0 +1,30 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+model:
+ _target_: models.PanLUNA.PanLUNA
+ patch_size: 32
+ embed_dim: 64
+ num_heads: 2
+ depth: 6
+ num_queries: 4
+ mlp_ratio: 4
+ drop_path: 0.0
+ num_classes: 0
\ No newline at end of file
diff --git a/README.md b/README.md
index 5fe540a..4d71334 100644
--- a/README.md
+++ b/README.md
@@ -13,6 +13,9 @@
+
+
+
@@ -25,6 +28,9 @@
+
+
+
@@ -32,7 +38,7 @@
Copyright (C) 2025-2026 ETH Zurich, Switzerland. SPDX-License-Identifier: Apache-2.0. See LICENSE file for details.
-Authors: Thorir Mar Ingolfsson, Anna Tegon, Berkay Döner, Xiaying Wang, Matteo Fasulo, Danaé Broustail, Yawei Li & Luca Benini.
+Authors: Thorir Mar Ingolfsson, Anna Tegon, Berkay Döner, Xiaying Wang, Matteo Fasulo, Danaé Broustail, Marija Zelic, Yawei Li & Luca Benini.
## About
@@ -50,6 +56,7 @@ Looking for ready-to-use weights of models? We host them on Hugging Face:
- **LUNA** ([paper](https://arxiv.org/abs/2510.22257)) [](https://huggingface.co/PulpBio/LUNA)
- **TinyMyo** ([paper](https://arxiv.org/abs/2512.15729)) [](https://huggingface.co/PulpBio/TinyMyo)
- **LuMamba** ([paper](https://arxiv.org/abs/2603.19100)) [](https://huggingface.co/PulpBio/LuMamba)
+- **PanLUNA** ([paper](https://arxiv.org/pdf/2604.04297)) [](https://huggingface.co/PulpBio/PanLUNA)
#### Why FEMBA?
@@ -227,6 +234,46 @@ Tips:
- Ensure `data_module:train/test/val` are initialized with the compatible dataset class.
- Configuration file includes sufficient `#CHANGEME` tags and further instructions for a working example.
+## Why PanLUNA?
+
+* Extending LUNA's channel-unification mechanism from topology invariance to **cross-modal fusion**, jointly processing EEG, ECG, and PPG within a single shared encoder via sensor-type embeddings - no modality-specific backbones, no paired multimodal data required during pretraining.
+* Pretrained on ~40,000 hours of heterogeneous biosignal data (TUEG, Siena, MIMIC-IV, CODE-15%, PulseDB) with a masked signal reconstruction objective.
+* Strong performance on unimodal and multimodal experiments.
+
+➡️ Model hub: __https://huggingface.co/PulpBio/PanLUNA__ 📄 Model card: __[PanLUNA on Hugging Face](https://huggingface.co/PulpBio/PanLUNA)__ — variants, configs, and fine-tuning walkthrough. 📜 Weights license: CC BY-ND 4.0 (use + redistribute unmodified weights with attribution; no redistribution of modified weights) 🧑🍳 PR-gated improvements: If you fine-tune internally and want your variant to become an official PanLUNA release, open a PR with configs, logs, and evals. We'll review; if it looks good, we'll retrain/validate and publish an official PanLUNA checkpoint.
+
+### What you'll find on the hub
+
+* Pre-trained checkpoint.
+* Instructions to get started on fine-tuning experiments.
+
+### Quick download with `huggingface_hub`:
+
+```
+pip install huggingface_hub
+```
+
+```python
+from huggingface_hub import snapshot_download
+
+# downloads all pre-trained variants and safetensors into ./checkpoints/PanLUNA
+snapshot_download(repo_id="PulpBio/PanLUNA", repo_type="model", local_dir="checkpoints/PanLUNA")
+```
+
+Include the safetensors checkpoint path as input and run fine-tuning on the command line:
+
+```bash
+python -u run_train.py +experiment=PanLUNA_finetune \
+ pretrained_safetensors_path=/absolute/path/to/checkpoints/PanLUNA/PanLUNA.safetensors
+```
+
+### Tips:
+
+* Data preprocessing scripts are provided in `/make_datasets` for various downstream datasets. Download the corresponding dataset, locate its preprocessing script by name matching, and adjust key parameters.
+* Adapt configuration file `config/experiment/PanLUNA_finetune.yaml` to your specific task with correct data module (for unimodal experiments `config/data_module/finetune_data_module_unimodal_PanLUNA` or multimodal experiments `config/data_module/finetune_data_module_multimodal_PanLUNA`), classification type (binary `bc`, multi-class `mcc` and multi-label `mlp`) and change `model.num_classes` accordingly. For different fine-tuning strategies adjust `finetuning.mode` parameter with `lora` for Low-Rank Adaptation, `freeze_encoder` for frozen backbone or select `full` for complete adaptation.
+ * Ensure `data_module:train/test/val` are initialized with the compatible dataset class. Leverage `config/data_module/finetune_data_module_multimodal_PanLUNA:data_module.train.channel_groups` to specify channels available in the dataset and `channel_start`/`channel_end` to tweak the used subset.
+ * Configuration file includes sufficient `#CHANGEME` tags and further instructions for a working example.
+
## Features
- **Modular Design**: The repository is organized into modules for data loading, models, training tasks, and more, making it easy to extend and adapt for new research projects.
diff --git a/dataset_types.yaml b/dataset_types.yaml
new file mode 100644
index 0000000..34c5e58
--- /dev/null
+++ b/dataset_types.yaml
@@ -0,0 +1,33 @@
+dataset_types: # both pretraining datasets and unimodal finetuning datasets
+ tueg:
+ channels: [FP1-F7, F7-T3, T3-T5, T5-O1, FP2-F8, F8-T4, T4-T6, T6-O2, T3-C3, C3-CZ, CZ-C4, C4-T4, FP1-F3, F3-C3, C3-P3, P3-O1, FP2-F4, F4-C4, C4-P4, P4-O2, A1-T3, T4-A2]
+ location_fn: "eeg"
+ sensor_type: 1
+ siena:
+    channels: [FP1, FP2, F3, C3, P3, O1, F7, T3, T5, FC1, FC5, CP1, CP5, F9, FZ, CZ, PZ, F4, C4, P4, O2, F8, T4, T6, FC2, FC6, CP2, CP6, F10]
+ location_fn: "eeg"
+ sensor_type: 1
+ code15:
+ channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6]
+ location_fn: "ecg"
+ sensor_type: 0
+ mimicIV:
+ channels: [I, II, III, AVR, AVF, AVL, V1, V2, V3, V4, V5, V6]
+ location_fn: "ecg"
+ sensor_type: 0
+ pulsedb:
+ channels: [II, PPG]
+ location_fn: "ecg"
+ sensor_type: [0, 2]
+ ptbxl:
+ channels: [I, II, III, AVR, AVF, AVL, V1, V2, V3, V4, V5, V6]
+ location_fn: "ecg"
+ sensor_type: 0
+ csn:
+ channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6]
+ location_fn: "ecg"
+ sensor_type: 0
+ cpsc2018:
+ channels: [I, II, III, AVR, AVL, AVF, V1, V2, V3, V4, V5, V6]
+ location_fn: "ecg"
+ sensor_type: 0
\ No newline at end of file
diff --git a/finetune_data_module_multimodal_PanLUNA.yaml b/finetune_data_module_multimodal_PanLUNA.yaml
new file mode 100644
index 0000000..1f52129
--- /dev/null
+++ b/finetune_data_module_multimodal_PanLUNA.yaml
@@ -0,0 +1,109 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+# This is example for SEED-VII dataset
+# To adjust for other multimodal datasets replace channel_groups field with appropriate modalities and their channels
+# Adjust slicing with channel start and end indexes
+data_module:
+ _target_: data_module.finetune_data_module.FinetuneDataModule
+ name: "eeg"
+ cfg:
+ num_workers: ${num_workers}
+ batch_size: ${batch_size}
+ train:
+ _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ <<: &seed_channels
+ channel_groups:
+ eeg:
+ - FP1
+ - FPZ
+ - FP2
+ - AF3
+ - AF4
+ - F7
+ - F5
+ - F3
+ - F1
+ - FZ
+ - F2
+ - F4
+ - F6
+ - F8
+ - FT7
+ - FC5
+ - FC3
+ - FC1
+ - FCZ
+ - FC2
+ - FC4
+ - FC6
+ - FT8
+ - T7
+ - C5
+ - C3
+ - C1
+ - CZ
+ - C2
+ - C4
+ - C6
+ - T8
+ - TP7
+ - CP5
+ - CP3
+ - CP1
+ - CPZ
+ - CP2
+ - CP4
+ - CP6
+ - TP8
+ - P7
+ - P5
+ - P3
+ - P1
+ - PZ
+ - P2
+ - P4
+ - P6
+ - P8
+ - PO7
+ - PO5
+ - PO3
+ - POZ
+ - PO4
+ - PO6
+ - PO8
+ - CB1
+ - O1
+ - OZ
+ - O2
+ - CB2
+ ecg:
+ - II
+ channel_start: 0 # change for fraction of channels
+ channel_end: 63
+ val:
+ _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ <<: *seed_channels
+ test:
+ _target_: 'datasets.finetuning_multimodal_datasets_PanLUNA.FinetuningMultimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ <<: *seed_channels
\ No newline at end of file
diff --git a/finetune_data_module_unimodal_PanLUNA.yaml b/finetune_data_module_unimodal_PanLUNA.yaml
new file mode 100644
index 0000000..f6ccf55
--- /dev/null
+++ b/finetune_data_module_unimodal_PanLUNA.yaml
@@ -0,0 +1,49 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+defaults:
+ - dataset_types
+
+data_module:
+ _target_: data_module.finetune_data_module.FinetuneDataModule
+ cfg:
+ num_workers: ${num_workers}
+ batch_size: ${batch_size}
+ train:
+ _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12 # example
+ channels: ${dataset_types.csn.channels} # example
+ location_fn: ${dataset_types.csn.location_fn}
+ sensor_type: ${dataset_types.csn.sensor_type}
+ val:
+ _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.csn.channels}
+ location_fn: ${dataset_types.csn.location_fn}
+ sensor_type: ${dataset_types.csn.sensor_type}
+ test:
+ _target_: 'datasets.finetuning_unimodal_datasets_PanLUNA.FinetuningUnimodal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.csn.channels}
+ location_fn: ${dataset_types.csn.location_fn}
+ sensor_type: ${dataset_types.csn.sensor_type}
\ No newline at end of file
diff --git a/finetune_task_PanLUNA.py b/finetune_task_PanLUNA.py
new file mode 100644
index 0000000..050b945
--- /dev/null
+++ b/finetune_task_PanLUNA.py
@@ -0,0 +1,419 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+
+import torch
+import torch.nn as nn
+import pytorch_lightning as pl
+import hydra
+import torch_optimizer as torch_optim
+import torch.nn.functional as F
+from torchmetrics import MetricCollection
+from torchmetrics.classification import (
+ Accuracy, Precision, Recall, AUROC,
+ AveragePrecision, CohenKappa, F1Score
+)
+from peft import LoraConfig, get_peft_model, TaskType
+
class ChannelWiseNormalize:
    """Callable that z-scores each channel of a (B, C, T) tensor along time.

    Uses the sample (unbiased) standard deviation; `eps` is added to the
    denominator so constant (zero-variance) channels do not divide by zero.
    """

    def __init__(self, eps=1e-8):
        # Guard against division by zero on flat signals.
        self.eps = eps

    def __call__(self, tensor):
        # Normalization is a fixed preprocessing step, so no autograd tracking.
        with torch.no_grad():
            mu = tensor.mean(dim=2, keepdim=True)
            sigma = tensor.std(dim=2, keepdim=True)
            centered = tensor - mu
            return centered / (sigma + self.eps)
+
class FinetuneTask(pl.LightningModule):
    """
    PyTorch Lightning module for fine-tuning a classification model, with support for:

    - Classification types:
        - `bc`: Binary Classification
        - `ml`: Multi-Label Classification
        - `mc`: Multi-Label Classification for TUAR
        - `mcc`: Multi-Class Classification
        - `mmc`: Multi-Class Multi-Output Classification
        - `mlp`: Multi-Label Classification for ECG Tasks, e.g. PTB-XL Super (target of each sample of shape (12, 2560) is [0, 1, 0, 0, 1] having 5 possible labels for each sample and possibility of multiple labels per sample)

    - Different fine-tuning strategies available: LoRA, frozen encoder and full.
    - Metric logging during training, validation, and testing, including accuracy, precision, recall, F1 score, AUROC, and more
    - Optional input normalization with configurable normalization functions
    - Custom optimizer support including SGD, Adam, AdamW, and LAMB
    - Learning rate schedulers with configurable scheduling strategies
    - Layer-wise learning rate decay for fine-grained learning rate control across model blocks
    """
    def __init__(self, hparams):
        """
        Initialize the FinetuneTask module.

        Args:
            hparams (DictConfig): Hyperparameters and configuration loaded via Hydra.
        """
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)

        print("--- MODEL MODULE NAMES ---")
        for name, module in self.model.named_modules():
            # Print only the linear or attention-related modules to keep the list short
            if any(keyword in name for keyword in ['qkv', 'proj', 'attn', 'cross_attention']):
                print(f"'{name}'")

        self.num_classes = self.hparams.model.num_classes
        self.classification_type = self.hparams.classification_type

        # Input normalization. Default to False so the train/val/test steps can
        # always read `self.normalize` — previously the attribute was only set
        # when normalization was enabled, raising AttributeError otherwise.
        self.normalize = False
        if self.hparams.input_normalization is not None and self.hparams.input_normalization.normalize:
            self.normalize = True
            self.normalize_fct = ChannelWiseNormalize()

        # Loss function: multi-label types ('mc', 'mlp') use BCE-with-logits on
        # multi-hot targets; all other types use cross-entropy on class indices.
        if self.classification_type in ("mc", "mlp"):
            self.criterion = nn.BCEWithLogitsLoss()
        else:
            self.criterion = nn.CrossEntropyLoss()

        # Classification mode detection (binary vs. multiclass metric task).
        if not isinstance(self.num_classes, int):
            raise TypeError("Number of classes must be an integer.")
        elif self.num_classes < 2:
            # Message now matches the check (was: "at least 1").
            raise ValueError("Number of classes must be at least 2.")
        elif self.num_classes == 2:
            self.classification_task = "binary"
        else:
            self.classification_task = "multiclass"

        # Metrics: label metrics consume hard predictions, logit metrics
        # consume probabilities/logits (see `_handle_binary`).
        if self.classification_type != "mlp":
            label_metrics = MetricCollection([
                Accuracy(task=self.classification_task, num_classes=self.num_classes, average="macro"),
                Recall(task=self.classification_task, num_classes=self.num_classes, average="macro"),
                Precision(task=self.classification_task, num_classes=self.num_classes, average="macro"),
                F1Score(task=self.classification_task, num_classes=self.num_classes, average="macro"),
                CohenKappa(task=self.classification_task, num_classes=self.num_classes)
            ])
            logit_metrics = MetricCollection([
                AUROC(task=self.classification_task, num_classes=self.num_classes, average="macro"),
                AveragePrecision(task=self.classification_task, num_classes=self.num_classes, average="macro"),
            ])
        else:
            # 'mlp' targets are multi-hot vectors, so use multilabel metrics.
            label_metrics = MetricCollection([
                Accuracy(task='multilabel', num_labels=self.num_classes, average="macro"),
                Recall(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0),
                Precision(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0),
                F1Score(task='multilabel', num_labels=self.num_classes, average="macro", zero_division=0),
            ])

            logit_metrics = MetricCollection([
                AUROC(task='multilabel', num_labels=self.num_classes, average="macro"),
                AveragePrecision(task='multilabel', num_labels=self.num_classes, average="macro"),
            ])

        # Separate clones per stage so running statistics never mix.
        self.train_label_metrics = label_metrics.clone(prefix='train_')
        self.val_label_metrics = label_metrics.clone(prefix='val_')
        self.test_label_metrics = label_metrics.clone(prefix='test_')
        self.train_logit_metrics = logit_metrics.clone(prefix='train_')
        self.val_logit_metrics = logit_metrics.clone(prefix='val_')
        self.test_logit_metrics = logit_metrics.clone(prefix='test_')

    def load_pretrained_checkpoint(self, model_ckpt):
        """
        Load a pretrained checkpoint and apply the configured fine-tuning
        strategy (`lora`, `freeze_encoder`, or full fine-tuning).

        Args:
            model_ckpt (str | None): Path to the pretrained checkpoint.
                If None, the method is a no-op.
        """
        if model_ckpt is not None:
            assert self.model.classifier is not None
            print("Loading pretrained checkpoint")
            # Load on CPU so GPU-saved checkpoints restore on any host;
            # Lightning moves the module to the target device afterwards.
            ckpt = torch.load(model_ckpt, map_location="cpu")
            state_dict = ckpt['state_dict']

            # Remove decoder head and channel embedding weights since they are not needed for fine-tuning
            state_dict = {k: v for k, v in state_dict.items() if 'decoder_head' not in k and "channel_emb" not in k}

            # Strip the Lightning "model." prefix so keys match self.model.
            new_state_dict = {}
            for k, v in state_dict.items():
                if k.startswith("model."):
                    new_state_dict[k[len("model."):]] = v
                else:
                    new_state_dict[k] = v

            missing, unexpected = self.model.load_state_dict(new_state_dict, strict=False)
            print("missing:", missing)
            print("unexpected:", unexpected)

            if self.hparams.finetuning.mode == "lora":
                print("=> Applying LoRA strategy...")
                lora_cfg = self.hparams.finetuning.lora

                config = LoraConfig(
                    task_type=TaskType.FEATURE_EXTRACTION,
                    r=lora_cfg.r,
                    lora_alpha=lora_cfg.alpha,
                    lora_dropout=lora_cfg.dropout,
                    target_modules=lora_cfg.target_modules
                )

                self.model = get_peft_model(self.model, config)
                self.model.print_trainable_parameters()

                # PEFT freezes everything outside the adapters; re-enable the
                # classification head so it is trained as well.
                for p in self.model.base_model.model.classifier.parameters():
                    p.requires_grad = True

            elif self.hparams.finetuning.mode == "freeze_encoder":
                print("=> Applying Freeze Encoder strategy...")

                # Freeze the whole backbone; only the classifier head trains.
                for name, param in self.model.named_parameters():
                    param.requires_grad = False
                    if 'classifier' in name:
                        param.requires_grad = True

            else:
                print("=> Applying Full Fine-tuning strategy...")

            print("Pretrained model ready.")

    def generate_fake_mask(self, batch_size, C, T):
        """
        Create a dummy all-False mask tensor to simulate attention masking.

        Args:
            batch_size (int): Number of samples.
            C (int): Number of channels.
            T (int): Temporal dimension.

        Returns:
            torch.Tensor: Boolean mask tensor of shape (B, C, T).
        """
        return torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device)

    def _step(self, X, mask, channel_locations, channel_names, sensor_type):
        """
        Perform forward pass and post-process predictions.

        Args:
            X (torch.Tensor): Input tensor.
            mask (torch.Tensor | None): Attention mask tensor.
            channel_locations (torch.Tensor): Per-channel location features.
            channel_names: Per-channel name indices (may be None).
            sensor_type (torch.Tensor): Per-channel sensor type ids.

        Returns:
            dict: Dictionary containing predicted labels, probabilities, and logits.

        Raises:
            ValueError: If `classification_type` is not one of the supported types.
        """
        y_pred_logits, _ = self.model(x_signal=X, mask=mask, channel_locations=channel_locations, sensor_type=sensor_type, channel_names=channel_names)

        if self.classification_type in ("bc", "mcc", "ml"):
            y_pred_probs = torch.softmax(y_pred_logits, dim=1)
            y_pred_label = torch.argmax(y_pred_probs, dim=1)
        elif self.classification_type in ("mc", "mlp"):
            # Multi-label: independent sigmoid per label, 0.5 threshold.
            y_pred_probs = torch.sigmoid(y_pred_logits)
            y_pred_label = torch.round(y_pred_probs)
        elif self.classification_type == "mmc":
            # Multi-output: flatten to one row per output head.
            # NOTE(review): 6 appears to be the per-output class count — confirm
            # against the model configuration.
            y_pred_logits = y_pred_logits.view(-1, 6)
            y_pred_probs = torch.sigmoid(y_pred_logits)
            y_pred_label = torch.argmax(y_pred_probs, dim=-1)
        else:
            # Previously an unknown type fell through and crashed with
            # UnboundLocalError; fail explicitly instead.
            raise ValueError(f"Unsupported classification_type: {self.classification_type}")

        return {
            'label': y_pred_label,
            'probs': y_pred_probs,
            'logits': y_pred_logits,
        }

    def training_step(self, batch, batch_idx):
        X, y = batch["input"], batch["label"]

        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]

        if self.normalize:
            X = self.normalize_fct(X)

        mask = None
        y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type)
        if self.classification_type == "mmc":
            # Flatten multi-output targets to align with the (-1, 6) logits
            # produced in `_step` (mirrors validation/test; previously missing
            # here, causing a logits/targets shape mismatch).
            y = y.view(-1)
            loss = self.criterion(y_pred['logits'], y)
        elif self.classification_type in ("mc", "mlp"):
            # BCEWithLogitsLoss expects float multi-hot targets.
            loss = self.criterion(y_pred['logits'], y.float())
        else:
            loss = self.criterion(y_pred['logits'], y)

        self.train_label_metrics(y_pred['label'], y)
        self.train_logit_metrics(self._handle_binary(y_pred['logits']), y)
        self.log_dict(self.train_label_metrics, on_step=True, on_epoch=False)
        self.log_dict(self.train_logit_metrics, on_step=True, on_epoch=False)
        self.log('train_loss', loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        # Log the current learning rate of the first param group for monitoring.
        lr = self.optimizers().param_groups[0]['lr']
        self.log('lr', lr, on_step=True, on_epoch=False, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        X, y = batch["input"], batch["label"]

        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]

        if self.normalize:
            X = self.normalize_fct(X)

        mask = None
        y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type)
        if self.classification_type == "mmc":
            # Flatten targets to match the (-1, 6) logits from `_step`.
            y = y.view(-1)
            loss = self.criterion(y_pred['logits'], y)
        elif self.classification_type in ("mc", "mlp"):
            loss = self.criterion(y_pred['logits'], y.float())
        else:
            loss = self.criterion(y_pred['logits'], y)

        self.val_label_metrics(y_pred['label'], y)
        self.val_logit_metrics(self._handle_binary(y_pred['logits']), y)
        self.log_dict(self.val_label_metrics, on_step=False, on_epoch=True)
        self.log_dict(self.val_logit_metrics, on_step=False, on_epoch=True)
        self.log('val_loss', loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def test_step(self, batch, batch_idx):
        X, y = batch["input"], batch["label"]

        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]

        if self.normalize:
            X = self.normalize_fct(X)

        mask = None
        y_pred = self._step(X, mask, channel_locations, channel_names, sensor_type)
        if self.classification_type == "mmc":
            # Flatten targets to match the (-1, 6) logits from `_step`.
            y = y.view(-1)
            loss = self.criterion(y_pred['logits'], y)
        elif self.classification_type in ("mc", "mlp"):
            loss = self.criterion(y_pred['logits'], y.float())
        else:
            loss = self.criterion(y_pred['logits'], y)

        self.test_label_metrics(y_pred['label'], y)
        self.test_logit_metrics(self._handle_binary(y_pred['logits']), y)
        self.log_dict(self.test_label_metrics, on_step=False, on_epoch=True)
        self.log_dict(self.test_logit_metrics, on_step=False, on_epoch=True)
        self.log('test_loss', loss, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def on_train_epoch_end(self):
        self.log_dict(self.train_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)
        self.log_dict(self.train_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)

    def on_validation_epoch_end(self):
        self.log_dict(self.val_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)
        self.log_dict(self.val_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)

    def on_test_epoch_end(self):
        self.log_dict(self.test_label_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)
        self.log_dict(self.test_logit_metrics, prog_bar=True, logger=True, sync_dist=True, on_step=False, on_epoch=True)

    def lr_scheduler_step(self, scheduler, metric):
        """
        Custom scheduler step function for step-based LR schedulers.

        ReduceLROnPlateau consumes the monitored metric; other (timm-style)
        schedulers are advanced by the global optimization step.
        """
        if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
            scheduler.step(metric)
        else:
            scheduler.step_update(num_updates=self.global_step)

    def configure_optimizers(self):
        """
        Configure the optimizer and learning rate scheduler.

        Applies layer-wise LR decay: a parameter in block i gets
        base_lr * decay_factor ** (num_blocks - i).

        Returns:
            dict: Configuration dictionary with optimizer and LR scheduler.
        """
        num_blocks = self.hparams.model.depth
        params_to_pass = []
        base_lr = self.hparams.optimizer.lr
        decay_factor = self.hparams.layerwise_lr_decay

        # NOTE(review): the block index is assumed to be the second dotted
        # component of the parameter name (e.g. "block.3.attn.qkv.weight");
        # confirm this matches the backbone's naming scheme.
        for name, param in self.model.named_parameters():
            lr = base_lr
            if self.hparams.finetuning.mode in ("full", "freeze_encoder"):
                if "block." in name or 'norm_layers' in name:
                    block_nr = int(name.split('.')[1])
                    lr *= decay_factor ** (num_blocks - block_nr)
                params_to_pass.append({"params": param, "lr": lr})
            else:
                # LoRA mode: only norm layers receive the decayed LR.
                if 'norm_layers' in name:
                    block_nr = int(name.split('.')[1])
                    lr *= decay_factor ** (num_blocks - block_nr)
                params_to_pass.append({"params": param, "lr": lr})

        if self.hparams.optimizer.optim == "SGD":
            optimizer = torch.optim.SGD(params_to_pass, lr=base_lr, momentum=self.hparams.optimizer.momentum)
        elif self.hparams.optimizer.optim == 'Adam':
            optimizer = torch.optim.Adam(params_to_pass, lr=base_lr, weight_decay=self.hparams.optimizer.weight_decay)
        elif self.hparams.optimizer.optim == 'AdamW':
            optimizer = torch.optim.AdamW(params_to_pass, lr=base_lr, weight_decay=self.hparams.optimizer.weight_decay, betas=self.hparams.optimizer.betas)
        elif self.hparams.optimizer.optim == 'LAMB':
            optimizer = torch_optim.Lamb(params_to_pass, lr=base_lr)
        else:
            raise NotImplementedError("No valid optimizer name")

        if self.hparams.scheduler_type == "multi_step_lr":
            scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer)
        else:
            # 'cosine' and any other step-based scheduler need the total number
            # of optimization steps (the two original branches were identical).
            scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer=optimizer,
                                                total_training_opt_steps=self.trainer.estimated_stepping_batches)
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def _handle_binary(self, preds):
        """
        Special handling for binary classification probabilities.

        Args:
            preds (torch.Tensor): Logit outputs.

        Returns:
            torch.Tensor: Positive-class column for binary tasks, otherwise
            the logits unchanged.
        """
        if self.classification_task == 'binary' and self.classification_type != 'mc':
            return preds[:, 1].squeeze()
        else:
            return preds
diff --git a/finetune_task_PanLUNA.yaml b/finetune_task_PanLUNA.yaml
new file mode 100644
index 0000000..5869f2e
--- /dev/null
+++ b/finetune_task_PanLUNA.yaml
@@ -0,0 +1,22 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+task:
+ _target_: 'tasks.finetune_task_PanLUNA.FinetuneTask'
\ No newline at end of file
diff --git a/finetuning_multimodal_datasets_PanLUNA.py b/finetuning_multimodal_datasets_PanLUNA.py
new file mode 100644
index 0000000..5a1ec2a
--- /dev/null
+++ b/finetuning_multimodal_datasets_PanLUNA.py
@@ -0,0 +1,126 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import torch
+import h5py
+import numpy as np
+from typing import Optional
+
+from models.modules.lead_positions import (
+ map_lead_labels_to_angles,
+ get_channel_indices,
+ get_channel_locations
+)
+
+MODALITY_TO_SENSOR_TYPE = {"ecg": 0, "eeg": 1, "ppg": 2}
+KNOWN_MODALITIES = {"eeg", "ecg", "ppg"}
+
+def _compute_channel_location(
+ channel_names: list[str],
+ channel_modalities: list[str]
+):
+ """
+ Build a (n_channels, max_dim) location array.
+ Handles mixed EEG/ECG by zero-padding to the largest feature dimension.
+ """
+ raw_locs = []
+ for ch, mod in zip(channel_names, channel_modalities):
+ if mod == "eeg":
+ loc = np.stack(get_channel_locations([ch]), axis=0)
+ elif mod == "ecg" or mod == "ppg":
+ loc = map_lead_labels_to_angles([ch])
+ raw_locs.append(loc)
+
+ max_dim = max(loc.shape[1] for loc in raw_locs)
+ padded = [
+ np.pad(loc, ((0, 0), (0, max_dim - loc.shape[1])), mode="constant")
+ if loc.shape[1] < max_dim else loc
+ for loc in raw_locs
+ ]
+
+ return np.vstack(padded)
+
class FinetuningMultimodal_Dataset(torch.utils.data.Dataset):
    """
    Unified class for multimodal finetuning datasets. Handles channel selection and channel location padding with support of hydra configuration.

    Args:
        hdf5_file: Path to the .h5 file.
        channel_groups: All channels in the dataset organized by modality as in the finetune_data_module_multimodal.yaml
        channel_start: Starting index for channel slicing. Allows taking all or fraction of modalities.
        channel_end: Ending index for channel slicing.

    Raises:
        ValueError: If a `channel_groups` key is not one of "eeg", "ecg", "ppg".
    """
    def __init__(
        self,
        hdf5_file: str,
        channel_groups: dict[str, list[str]],
        channel_start: Optional[int] = None,
        channel_end: Optional[int] = None,
    ):
        super().__init__()
        self._x_slice = slice(channel_start, channel_end)

        # Flatten channel_groups into parallel lists of names and modalities.
        self.channel_names = []
        flat_modalities = []

        for modality, names in channel_groups.items():
            # Previously an unknown modality kept its channel names but added
            # no modality entries, silently desynchronizing the parallel lists
            # (wrong sensor types / locations for all following channels).
            if modality not in KNOWN_MODALITIES:
                raise ValueError(
                    f"Unknown modality '{modality}' in channel_groups; "
                    f"expected one of {sorted(KNOWN_MODALITIES)}."
                )
            self.channel_names.extend(names)
            flat_modalities.extend([modality] * len(names))

        self.channel_indices = torch.tensor(get_channel_indices(self.channel_names), dtype=torch.long)

        # Obtain per-channel location features (zero-padded across modalities).
        locs = _compute_channel_location(self.channel_names, flat_modalities)
        self.channel_locations = torch.from_numpy(locs).float()
        self.sensor_type = torch.tensor([MODALITY_TO_SENSOR_TYPE[m] for m in flat_modalities], dtype=torch.long)

        # Open HDF5 and build a flat (group_key, sample_idx) index.
        self.data = h5py.File(hdf5_file, "r")
        self.keys = list(self.data.keys())
        self.index_map = []

        for key in self.keys:
            group_size = len(self.data[key]['X'])
            self.index_map.extend([(key, i) for i in range(group_size)])

    def __len__(self):
        # One entry per (group, sample) pair across the whole file.
        return len(self.index_map)

    def __getitem__(self, index):
        group_key, sample_idx = self.index_map[index]
        grp = self.data[group_key]

        # The channel slice is applied consistently to the signal and to all
        # per-channel metadata so they stay aligned.
        X = torch.FloatTensor(grp["X"][sample_idx])[self._x_slice, :]
        label = torch.tensor(grp["y"][sample_idx], dtype=torch.long)

        return_dict = {
            "input": X,
            "channel_names": self.channel_indices[self._x_slice],
            "channel_locations": self.channel_locations[self._x_slice],
            "sensor_type": self.sensor_type[self._x_slice],
            "label": label
        }

        return return_dict

    def __del__(self):
        # Mirror the unimodal dataset: release the HDF5 handle on collection
        # (the original leaked the open file).
        if hasattr(self, "data"):
            self.data.close()
+
+
+
\ No newline at end of file
diff --git a/finetuning_unimodal_datasets_PanLUNA.py b/finetuning_unimodal_datasets_PanLUNA.py
new file mode 100644
index 0000000..19dd1a0
--- /dev/null
+++ b/finetuning_unimodal_datasets_PanLUNA.py
@@ -0,0 +1,97 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import torch
+import h5py
+import numpy as np
+from models.modules.lead_positions import (
+ get_channel_indices,
+ get_channel_locations,
+ map_lead_labels_to_angles,
+)
+
class FinetuningUnimodal_Dataset(torch.utils.data.Dataset):
    """
    Unified class for unimodal finetuning datasets.

    Args:
        hdf5_file: Path to the .h5 file.
        channels: List of channel/lead names.
        location_fn: "eeg" or "ecg" — selects how channel locations are computed.
        sensor_type: 0 (ECG), 1 (EEG), 2 (PPG); either one scalar applied to
            every channel or a per-channel list.
        num_channels: Number of channels taken from the front of `channels`
            (None keeps all of them).
    """
    def __init__(
        self,
        hdf5_file: str,
        channels: list[str],
        location_fn: str,
        sensor_type: int | list[int],
        num_channels: int | None = None,
    ):
        super().__init__()
        # Explicit None check: with the previous truthiness test,
        # num_channels=0 silently selected ALL channels instead of none.
        channel_names = channels if num_channels is None else channels[:num_channels]
        self.num_channels = len(channel_names)

        if location_fn == "eeg":
            locs = np.stack(get_channel_locations(channel_names), axis=0)
            self.channel_locations = torch.from_numpy(locs).float()
        else:
            # ECG/PPG leads are encoded as angles rather than 3D positions.
            self.channel_locations = torch.FloatTensor(map_lead_labels_to_angles(channel_names))

        self.channel_indices = torch.tensor(get_channel_indices(channel_names), dtype=torch.long)

        # Scalar sensor types are broadcast to one entry per channel.
        if isinstance(sensor_type, list):
            self.sensor_type = torch.tensor(sensor_type, dtype=torch.long)
        else:
            self.sensor_type = torch.full((self.num_channels, ), sensor_type, dtype=torch.long)

        # Open HDF5 and build a flat (group_key, sample_idx) index.
        self.data = h5py.File(hdf5_file, "r")
        self.keys = list(self.data.keys())
        self.index_map = []

        for key in self.keys:
            # 'index_map' is a bookkeeping group in some files, not sample data.
            if key == 'index_map':
                continue
            group_size = len(self.data[key]["X"])
            self.index_map.extend([(key, i) for i in range(group_size)])

    def __len__(self):
        # One entry per (group, sample) pair across the whole file.
        return len(self.index_map)

    def __getitem__(self, index):
        group_key, sample_idx = self.index_map[index]
        grp = self.data[group_key]
        X = torch.FloatTensor(grp["X"][sample_idx])
        label = torch.tensor(grp["y"][sample_idx], dtype=torch.long)

        item = {
            "input": X,
            "channel_names": self.channel_indices,
            "channel_locations": self.channel_locations,
            "sensor_type": self.sensor_type,
            "label": label
        }

        return item

    def __del__(self):
        # Guarded: __init__ may have raised before the file was opened.
        if hasattr(self, "data"):
            self.data.close()
\ No newline at end of file
diff --git a/lead_positions.py b/lead_positions.py
new file mode 100644
index 0000000..cfe221b
--- /dev/null
+++ b/lead_positions.py
@@ -0,0 +1,212 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+
+
+
+import numpy as np
+import torch
+import torch.nn as nn
+import mne
+
+# SEED-VII EEG cap layout (62 EEG electrodes + one ECG channel).
+SEED_VII_CH_LIST = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2', 'ECG']
+
+# Standard 12-lead ECG ordering used by the MIMIC-IV-ECG export.
+MIMIC_IV_ECG_CHANNEL_LIST = [
+    'I',
+    'II',
+    'III',
+    'AVR',
+    'AVF',
+    'AVL',
+    'V1',
+    'V2',
+    'V3',
+    'V4',
+    'V5',
+    'V6'
+]
+
+# Siena Scalp EEG electrode names (unipolar).
+SIENA_CHANNEL_LIST = [
+    "FP1",
+    "FP2",
+    "F3",
+    "C3",
+    "P3",
+    "O1",
+    "F7",
+    "T3",
+    "T5",
+    "FC1",
+    "FC5",
+    "CP1",
+    "CP5",
+    "F9",
+    "FZ",
+    "CZ",
+    "PZ",
+    "F4",
+    "C4",
+    "P4",
+    "O2",
+    "F8",
+    "T4",
+    "T6",
+    "FC2",
+    "FC6",
+    "CP2",
+    "CP6",
+    "F10",
+]
+
+# TUEG bipolar derivations ("A-B" pairs); get_channel_locations() maps each
+# pair to the midpoint of the two electrode positions.
+TUEG_CHANNEL_LIST = [
+    "FP1-F7",
+    "F7-T3",
+    "T3-T5",
+    "T5-O1",
+    "FP2-F8",
+    "F8-T4",
+    "T4-T6",
+    "T6-O2",
+    "T3-C3",
+    "C3-CZ",
+    "CZ-C4",
+    "C4-T4",
+    "FP1-F3",
+    "F3-C3",
+    "C3-P3",
+    "P3-O1",
+    "FP2-F4",
+    "F4-C4",
+    "C4-P4",
+    "P4-O2",
+    "A1-T3",
+    "T4-A2",
+]
+
+# VitalDB: one ECG lead plus a PPG channel.
+VITALDB_CHANNEL_LIST = [
+    'II',
+    'PPG'
+]
+
+# HMC sleep recordings: the four EEG derivations kept after renaming.
+HMC_CHANNEL_LIST = [
+    "F4",
+    "C4",
+    "O2",
+    "C3",
+]
+
+all_channels = set()
+
+for ds in [
+ SEED_VII_CH_LIST,
+ MIMIC_IV_ECG_CHANNEL_LIST,
+ TUEG_CHANNEL_LIST,
+ SIENA_CHANNEL_LIST,
+ VITALDB_CHANNEL_LIST,
+ HMC_CHANNEL_LIST
+]:
+ for ch in ds:
+ all_channels.add(ch)
+CHANNEL_NAMES_TO_IDX = {ch: i for i, ch in enumerate(sorted(all_channels))}
+CHANNEL_IDX_TO_NAMES = {i: ch for ch, i in CHANNEL_NAMES_TO_IDX.items()}
+
+def get_channel_indices(channel_names):
+ indices = []
+ for name in channel_names:
+ indices.append(CHANNEL_NAMES_TO_IDX[name])
+ return indices
+
+def get_channel_names(channel_indices):
+ names = []
+ for idx in channel_indices:
+ names.append(CHANNEL_IDX_TO_NAMES[idx])
+ return names
+
+def get_channel_locations(channel_names):
+ if "-" in channel_names[0]:
+ names = list(set([part for ch in channel_names for part in ch.split('-')]))
+ else:
+ names = channel_names
+ ch_types = ['eeg'] * len(names) # Channel types
+ info = mne.create_info(ch_names=names, sfreq=256, ch_types=ch_types)
+ info = info.set_montage(mne.channels.make_standard_montage("standard_1005"),match_case=False,match_alias={'cb1': 'POO7', 'cb2': 'POO8'})
+ locs = []
+ for name in channel_names:
+ if name in TUEG_CHANNEL_LIST:
+ electrode1, electrode2 = name.split('-')
+ loc1 = info.get_montage().get_positions()['ch_pos'][electrode1]
+ loc2 = info.get_montage().get_positions()['ch_pos'][electrode2]
+ locs.append(((loc1 + loc2) / 2))
+ else:
+ locs.append(info.get_montage().get_positions()['ch_pos'][name])
+ return locs
+
+# We encode each ECG lead with 2-angle-tuple vector from https://www.ijcai.org/proceedings/2021/0495.pdf
+# NOTE(review): V2 and V3 share identical angles below — confirm against the
+# cited paper. PPG has no anatomical direction, so it gets a (0, 0) placeholder.
+lead_labels_mapping = {
+    'I': (np.pi/2, np.pi/2),
+    'II': (5*np.pi/6, np.pi/2),
+    'III': (5*np.pi/6, -np.pi/2),
+    'AVR': (np.pi/3, -np.pi/2),
+    'AVL': (np.pi/3, np.pi/2),
+    'AVF': (np.pi, np.pi/2),
+    'V1': (np.pi/2, -np.pi/18),
+    'V2': (np.pi/2, np.pi/18),
+    'V3': (np.pi/2, np.pi/18),
+    'V4': (11*np.pi/20, np.pi/6),
+    'V5': (8*np.pi/15, np.pi/3),
+    'V6': (8*np.pi/15, np.pi/2),
+    'PPG': (0.0, 0.0)
+}
+
+def map_lead_labels_to_angles(lead_labels):
+ """
+ Mapping lead labels to angles.
+
+ Args:
+ lead_labels (list): List of lead labels.
+
+ Returns:
+ mapping (np.array): Array of mapped leads to angles for specific lead list.
+ """
+ lead_labels_decoded = [lead.decode('utf-8') if isinstance(lead, bytes) else lead for lead in lead_labels]
+ mapping = np.array([lead_labels_mapping[lead] for lead in lead_labels_decoded], dtype=float)
+
+ return mapping
+
+class ChannelEmbeddings(nn.Module):
+ def __init__(self, embed_dim):
+ super(ChannelEmbeddings, self).__init__()
+ self.embeddings = nn.Embedding(len(CHANNEL_NAMES_TO_IDX), embed_dim)
+
+ def forward(self, indices):
+ return self.embeddings(indices)
+
+ def initialize_weights(self):
+ torch.init.normal_(self.embeddings.weight, std=0.02)
+
+class SensorEmbeddings(nn.Module):
+ def __init__(self, embed_dim):
+ super(SensorEmbeddings, self).__init__()
+ self.embeddings = nn.Embedding(3, embed_dim)
+
+ def forward(self, indices):
+ return self.embeddings(indices)
+
+ def initialize_weights(self):
+ torch.init.normal_(self.embeddings.weight, std=0.02)
\ No newline at end of file
diff --git a/make_code15_dataset.py b/make_code15_dataset.py
new file mode 100644
index 0000000..240b4eb
--- /dev/null
+++ b/make_code15_dataset.py
@@ -0,0 +1,146 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import h5py
+import argparse
+import multiprocessing
+import tqdm
+import os
+
+import numpy as np
+from process_raw_ecg import preprocess_signal, time_segmenting
+
+# Code15 native sampling rate (Hz) and its 12-lead channel ordering.
+sampling_freq = 400
+CODE15_channels = ['I', "II", "III", "AVR", "AVL", "AVF", "V1", "V2", "V3", "V4", "V5", "V6"]
+
+def dump_data(data_slice, output_dir, data_group_name, file_idx):
+ """
+ Write the data in H5 file.
+ """
+
+ data_group = data_slice
+ output_path = os.path.join(output_dir, f"CODE15_{file_idx}.h5")
+
+ with h5py.File(output_path, "a") as h5f:
+
+ if h5f.attrs.get("channel_names") is None:
+ h5f.attrs['channel_names'] = CODE15_channels
+
+ grp = h5f.create_group(data_group_name)
+ X_data = np.array(data_group, dtype=np.float16)
+ grp.create_dataset("X", data=X_data, dtype='float16')
+
+
+def process_single_file(file_path, output_dir, downsample_fs, split_signal, file_idx, worker_id):
+    """
+    Process a single HDF5 file of Code15 dataset.
+
+    Each sample has shape (4096, 12). Samples are either of length 10s or 7s.
+    Sampling frequency is 400Hz. When length is 7s long, there's 2800 samples and zeros are added to pad to 4096.
+    In case of 10s, there's no additional padding, but it's actually a signal.
+    To have consistency, I first strip 96 samples (48 at the beginning and 48 at the end) - it's going to be around 0.1s for full signals and just zeros for 7s long.
+    And then process it.
+
+    Args:
+        file_path (str): Path to the file.
+        output_dir (str): Output directory to which we are writing the H5 files.
+        downsample_fs (int): Downsampling frequency.
+        split_signal (int): Length of the split to segment the signal. Assumes no padding, no overlap.
+        file_idx (int): Shard index forwarded to dump_data for the output file name.
+        worker_id (int): Worker id, used to position this worker's tqdm bar.
+    """
+    batch = 1024
+    with h5py.File(file_path, "r") as f:
+
+        # Extract dataset storing the ECGs.
+        # We can't extract all the data in one np.array as it is too much.
+        ecgs = f['tracings']
+        N = ecgs.shape[0]
+
+        # We need to do it in batches
+        for start in tqdm.tqdm(range(0, N, batch), desc=f"Worker {worker_id}", position=worker_id):
+            session_name = f"code15_{start}"
+            end = min(start+batch, N)
+
+            # Extract batch_size samples and convert to np.array
+            ecg_batch = np.array(ecgs[start:end])
+
+            # Transpose to (batch, channels, time) and take the middle 4000
+            # samples (discarding the first and last 48)
+            ecg_batch = np.transpose(ecg_batch, axes=(0,2,1))[:, :, 48:4048]
+
+            # Band-pass 0.5-120 Hz and downsample to downsample_fs
+            processed_ecg = preprocess_signal(ecg_batch, sampling_freq, 0.5, 120, downsample_fs, None)
+            reshaped_ecg = processed_ecg
+
+            # Split into fixed-length windows if split_signal is provided
+            if split_signal is not None:
+
+                # NOTE(review): time_segmenting receives the original
+                # sampling_freq (400) even though the signal was already
+                # downsampled above — confirm this matches its contract.
+                reshaped_ecg = np.expand_dims(processed_ecg, axis=0)
+                signal_splitted = time_segmenting(reshaped_ecg, split_signal, sampling_freq, downsample_fs, None)
+                merged_splits = np.concatenate(signal_splitted, axis=0)
+                # Collapse (splits, batch, ch, t) into (splits*batch, ch, t)
+                reshaped_ecg = merged_splits.reshape(-1, merged_splits.shape[2], merged_splits.shape[3])
+
+            dump_data(reshaped_ecg, output_dir, session_name, file_idx)
+
+def loading_and_processing_parallel(files, output_dir, downsampling_fs, split_signal):
+ """
+ Load and process all the HDF5 files from Code15 dataset in the input directory in parallel.
+
+ Args:
+ files (list): List of HDF5 file paths to process.
+ output_dir (str): Directory to save processed pickle files.
+ downsampling_fs (int): Desired downsampling frequency.
+ split_signal (int or None): Length or segments to split the signals into (in seconds). If None, no splitting is done.
+ """
+ num_workers = 18
+ worker_args = []
+
+ # Prepare arguments for parallel processing
+ for idx, file in enumerate(files):
+ worker_args.append((file, output_dir, downsampling_fs, split_signal, idx))
+
+ print(f"Processing {len(worker_args)} groups of files:")
+
+ # Create the directory if it doesn't exist
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Use multiprocessing to parallelize the process
+ with multiprocessing.Manager() as manager:
+ with multiprocessing.Pool(num_workers) as pool:
+ pool.starmap(process_single_file, [(args[0], args[1], args[2], args[3], args[4], worker_id) for worker_id, args in enumerate(worker_args)])
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process Code15 HDF5 files and save as .h5 files.")
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory contraining MIMIC-IV-ECG files.')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed .h5 files.')
+ parser.add_argument('--downsampling_fs', type=int, default=256, help='Desired downsampling frequency.')
+ parser.add_argument('--split_signal', type=int, default=5, help='Length of segments to split the signals into (in seconds). If None, no splitting is done.')
+
+ args = parser.parse_args()
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+ downsampling_fs = args.downsampling_fs
+ split_signal = args.split_signal
+
+ # List all HDF5 files that are stored in the directory
+ files = os.listdir(input_dir)
+ path_files = [os.path.join(input_dir, file) for file in files]
+
+ # Parallel writing
+ loading_and_processing_parallel(path_files, output_dir, downsampling_fs, split_signal)
\ No newline at end of file
diff --git a/make_cpsc2018_dataset.py b/make_cpsc2018_dataset.py
new file mode 100644
index 0000000..4c49c99
--- /dev/null
+++ b/make_cpsc2018_dataset.py
@@ -0,0 +1,147 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import wfdb
+import numpy as np
+import pandas as pd
+import pickle
+import argparse
+import tqdm
+import ast
+import os
+import h5py
+from scipy.io import loadmat
+
+from process_raw_ecg import preprocess_signal
+
+def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
+ """
+ Lists all the pickle files in the source directory and writes them in the .h5 file.
+
+ Args:
+ source_dir (str): Source directory where pickle files are found.
+ target_file (str): Path to the target .h5 file.
+ finetune (bool): Whether data has label or not.
+ group_size (int): Group size in HDF5 saving.
+ """
+ # List all the files in the folder
+ files = sorted(os.listdir(source_dir))
+ data_group = []
+
+
+ with h5py.File(target_file, 'w') as h5f:
+ for i, file in enumerate(tqdm.tqdm(files)):
+ with open(os.path.join(source_dir, file), 'rb') as f:
+
+ sample = pickle.load(f)
+ data_group.append(sample)
+
+ if (i + 1) % group_size == 0 or i == len(files) - 1:
+
+ grp = h5f.create_group(f"data_group_{i // group_size}")
+ X_data = np.array([s['X'] for s in data_group])
+ grp.create_dataset("X", data=X_data)
+
+ if(finetune):
+ y_data = np.array([s['y'] for s in data_group])
+ grp.create_dataset("y", data=y_data)
+
+ data_group = []
+
+def process_csv_files(args, csv_file, split_type):
+ """
+ Process CPSC2018 ECG dataset based on a CSV file containing file paths.
+
+ Args:
+ args: Command line arguments.
+ csv_files (str): Path to the CSV file containing file paths.
+ split_type (str): Type of split ('train', 'val', 'test').
+ """
+ # Make directory for this split if it doesn't exist
+ os.makedirs(os.path.join(args.output_dir, split_type), exist_ok=True)
+
+ # Read the csv file
+ print(f"Processing {csv_file}...")
+ df = pd.read_csv(csv_file)
+
+ for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
+
+ # Load the signal
+ try:
+ file_path = os.path.join(args.input_dir, f"{row['filename']}.mat")
+ signal = loadmat(file_path)['val'][:, :2500]
+ ecg = np.array(signal)
+
+ sampling_rate = 500 # hopefully, it's 500Hz
+
+ # Preprocess the signal and time-segment if needed
+ processed_ecg = preprocess_signal(ecg, sampling_rate, 0.5, 120, args.downsample_fs, None)
+
+ # Get the labels - there's 9 classes in total and they take last 9 columns of the csv
+ y = row.iloc[-9:].values.astype(np.int8)
+
+ # Write splits to pickle files
+ data_dict = {"X": processed_ecg, "y": y}
+ dump_path = os.path.join(args.output_dir, split_type, f'cpsc2018-{idx}.pkl')
+
+ with open(dump_path, "wb") as f:
+ pickle.dump(data_dict, f)
+
+ except Exception as e:
+ print(f"Skipped file {file_path}. Error occured: {e}")
+ continue
+
+def main_splitted(args):
+ """
+ Main function for processing CPASC2018 dataset and saving as pickle file.
+ Uses splits provided by MERL ICML 2024 paper for consistency in comparison (https://github.com/cheliu-computation/MERL-ICML2024/tree/main/finetune/data_split).
+ """
+ train_csv = os.path.join(args.csv_files_dir, "icbeb_train.csv")
+ val_csv = os.path.join(args.csv_files_dir, "icbeb_val.csv")
+ test_csv = os.path.join(args.csv_files_dir, "icbeb_test.csv")
+
+ # Create output directory if it doesn't exist
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Write each .csv file into pickle files with X and y
+ process_csv_files(args, train_csv, split_type="train")
+ process_csv_files(args, val_csv, split_type="val")
+ process_csv_files(args, test_csv, split_type="test")
+
+ # Finally write to HDF5
+ to_do = [f'train', f'val', f'test']
+ for td in to_do:
+ if os.path.exists(output_dir + '/' + td + ".h5"):
+ print(f"File {td}.h5 already exists!")
+ else:
+ print(f"Creating file {td}.h5.")
+ create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5")
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process CPSC2018 ECG dataset and save as pickle files.")
+ parser.add_argument('--input_dir', type=str, default="#CHANGEME", help="Directory containing CPSC2018 WFDB files.")
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME', help="Directory to save processed pickle files.")
+ parser.add_argument("--csv_files_dir", type=str, default='#CHANGEME')
+ parser.add_argument("--downsample_fs", type=int, default=256, help="Desired downsampling frequency. If None, no downsampling is done.")
+
+ args = parser.parse_args()
+ output_dir = args.output_dir
+
+ main_splitted(args)
\ No newline at end of file
diff --git a/make_csn_dataset.py b/make_csn_dataset.py
new file mode 100644
index 0000000..14e18ee
--- /dev/null
+++ b/make_csn_dataset.py
@@ -0,0 +1,157 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import wfdb
+import numpy as np
+import pandas as pd
+import pickle
+import argparse
+import tqdm
+import ast
+import os
+import wfdb
+import h5py
+
+from process_raw_ecg import preprocess_signal, time_segmenting
+
+# Standard 12-lead names; appears unused in this script — kept for reference.
+leads = ['I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
+
+def process_csv_files(args, csv_file, split_type):
+ """
+ Process Chapman-Shaoxing ECG dataset based on a CSV file containing file paths.
+
+ Args:
+ args: Command line arguments.
+ csv_file (str): Path to the CSV file containing file paths.
+ split_type (str): Type of split ('train', 'val', 'test').
+ """
+ # Make directory for this split if it doesn't exist
+ os.makedirs(os.path.join(args.output_dir, split_type), exist_ok=True)
+
+ # Read the csv file
+ print(f"Processing {csv_file}...")
+ df = pd.read_csv(csv_file)
+
+ for idx, row in tqdm.tqdm(df.iterrows(), total=len(df)):
+
+ # Load the signal
+ try:
+ file_path = os.path.join(args.input_dir, row['ecg_path'].lstrip('/')).split('.')[0]
+ signal, fields = wfdb.rdsamp(file_path)
+ ecg = np.array(signal).T # transpose to have shape (channels, samples)
+
+ sampling_rate = fields['fs']
+
+ # Preprocess the signal and time-segment if needed
+ processed_ecg = preprocess_signal(ecg, sampling_rate, 0.5, 120, args.downsample_fs, None)
+ #splits = time_segmenting(processed_ecg, 5, sampling_rate, args.downsample_fs, None)
+
+ # Get the labels - there's 38 classes in total and they take last 38 columns of the csv
+ y = row.iloc[-38:].values.astype(np.int8)
+
+ # Write splits to pickle files
+ #for num, split in enumerate(splits):
+ data_dict = {"X": processed_ecg, "y": y}
+ dump_path = os.path.join(args.output_dir, split_type, f'chapman-{idx}.pkl')
+
+ with open(dump_path, "wb") as f:
+ pickle.dump(data_dict, f)
+ except Exception as e:
+ print(f"Skipped file {file_path}. Error occured: {e}")
+ continue
+
+def main_splitted(args):
+ """
+ Main function for processing CSN dataset and saving as pickle file.
+ Uses splits provided by MERL ICML 2024paper for consistency in comparison (https://github.com/cheliu-computation/MERL-ICML2024/tree/main/finetune/data_split).
+ Processing is done with the respect to the the pipeline my code expects.
+ Filtering, downsampling to 256 Hz and segmenting in 5s pieces is done here.
+
+ Args:
+ args: Parsed arguments from argparse.
+ """
+ train_csv = os.path.join(args.csv_files_dir, 'chapman_train.csv')
+ val_csv = os.path.join(args.csv_files_dir, 'chapman_val.csv')
+ test_csv = os.path.join(args.csv_files_dir, 'chapman_test.csv')
+
+ # Create output directory if it doesn't exist
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ # Write each csv file into pickle files with X and y
+ process_csv_files(args, train_csv, split_type = 'train')
+ process_csv_files(args, val_csv, split_type = 'val')
+ process_csv_files(args, test_csv, split_type = 'test')
+
+ # Finally, write to HDF5
+ to_do = [f'train', f'val', f'test']
+ for td in to_do:
+ if os.path.exists(output_dir + '/' + td + ".h5"):
+ print(f"File {td}.h5 already exists!")
+ else:
+ print(f"Creating file {td}.h5.")
+ create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5")
+
+
+def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
+ """
+ Lists all the pickle files in the source directory and writes them in the .h5 file.
+
+ Args:
+ source_dir (str): Source directory where pickle files are found.
+ target_file (str): Path to the target .h5 file.
+ finetune (bool): Whether data has label or not.
+ group_size (int): Group size in HDF5 saving.
+ """
+ # List all the files in the folder
+ files = sorted(os.listdir(source_dir))
+ data_group = []
+
+
+ with h5py.File(target_file, 'w') as h5f:
+ for i, file in enumerate(tqdm.tqdm(files)):
+ with open(os.path.join(source_dir, file), 'rb') as f:
+
+ sample = pickle.load(f)
+ data_group.append(sample)
+
+ if (i + 1) % group_size == 0 or i == len(files) - 1:
+
+ grp = h5f.create_group(f"data_group_{i // group_size}")
+ X_data = np.array([s['X'] for s in data_group])
+ grp.create_dataset("X", data=X_data)
+
+ if(finetune):
+ y_data = np.array([s['y'] for s in data_group])
+ grp.create_dataset("y", data=y_data)
+
+ data_group = []
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process Chapman-Shaoxing ECG dataset and save as pickle files.")
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory containing Chapman-Shaoxing ECG WFDB files.')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed pickle files.')
+ parser.add_argument('--csv_files_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--downsample_fs', type=int, default=256, help='Desired downsampling frequency. If None, no downsampling is done.')
+
+ args = parser.parse_args()
+ output_dir = args.output_dir
+
+ main_splitted(args)
+
\ No newline at end of file
diff --git a/make_hmc_dataset.py b/make_hmc_dataset.py
new file mode 100644
index 0000000..70f2c17
--- /dev/null
+++ b/make_hmc_dataset.py
@@ -0,0 +1,187 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import os
+import mne
+import argparse
+import pickle
+import h5py
+
+import numpy as np
+
+from pathlib import Path
+from tqdm import tqdm
+
+# Channels kept from each HMC recording (4 EEG derivations + ECG).
+standard = {'EEG F4-M1', 'EEG C4-M1', 'EEG O2-M1', 'EEG C3-M2', 'ECG'}
+
+# Rename map: strip the reference electrode so names match the montage.
+# NOTE(review): 'ECG' is mapped to lead 'I' — confirm this matches the actual
+# HMC ECG electrode placement.
+mapping = {
+    'EEG F4-M1': 'F4',
+    'EEG C4-M1': 'C4',
+    'EEG O2-M1': 'O2',
+    'EEG C3-M2': 'C3',
+    'ECG': 'I'
+}
+all_channels = list(standard)
+
+def match_annotations_and_files_and_split(input_dir):
+ """
+ Lists all .edf files in the directory and matches signal files with annotations.
+ Additionally, splits into train, val and test according to https://arxiv.org/pdf/2504.19596.
+ First 100 subjects goes to train, next 25 for validation and next 26 for test.
+
+ Args:
+ input_dir (pathlib.Path): Path to the directory with recordings.
+ """
+
+ files = input_dir.rglob('*.edf')
+
+ # .edf files are signal and _sleepscoring.edf are annotations
+ signal_files = []
+ annotation_files = []
+ for file in files:
+
+ file_name = file.parts[-1][:-4]
+
+ if len(file_name.split('_')) == 1: # if there's no underscore, it's signal file
+ signal_files.append(file)
+ else: # it's annotation file
+ annotation_files.append(file)
+
+ # Files are already sorted
+ train_files, train_ann = signal_files[:100], annotation_files[:100]
+ val_files, val_ann = signal_files[100:125], annotation_files[100:125]
+ test_files, test_ann = signal_files[125:], annotation_files[125:]
+
+ return train_files, train_ann, val_files, val_ann, test_files, test_ann
+
+def dump_pickle(X, y, file_name, split):
+ """
+ Writes data to a pickle file.
+
+ Args:
+ X (np.array): Data to write to pickle. Shape: (num_samples, num_channels, sample_length)
+ y (np.array): Label corresponding to each sample. 5-class single label classification. Shape: (num_samples,)
+ file_name (str): Name to give to a pickle file.
+ split (str): Either "train", "val" or "test".
+ """
+
+ data_dict = {"X": X, "y": y}
+ dump_path = os.path.join(output_dir, split, f"{file_name}.pkl")
+
+ # Write data to the path
+ with open(dump_path, "wb") as f:
+ pickle.dump(data_dict, f)
+
+ print(f"Dumped into {file_name}.pkl pickle!")
+
+def create_hdf5(output_dir, split):
+
+ folder_path = os.path.join(output_dir, split)
+ target_file = os.path.join(output_dir, f"{split}.h5")
+ files = sorted(os.listdir(folder_path))
+
+ with h5py.File(target_file, 'w') as h5f:
+ for i, file in enumerate(tqdm(files)):
+ with open(os.path.join(folder_path, file), 'rb') as f:
+
+ sample = pickle.load(f)
+
+ group = h5f.create_group(f"data_group_{i}")
+ X_data = np.array(sample["X"])
+ group.create_dataset("X", data=X_data)
+
+ y_data = np.array(sample['y'])
+ group.create_dataset("y", data=y_data)
+
+
+def process_split(signal_files, annotation_files, output_dir, split):
+    """
+    Process one split (train/val/test).
+    Adds annotations, extracts relevant channels, filters, and epochs based on sleep stage annotations.
+
+    Args:
+        signal_files (list[Path]): List of paths to the signal .edf files.
+        annotation_files (list[Path]): List of paths to the annotation .edf files.
+        output_dir (str): Output directory to write files.
+        split (str): Either "train", "val" or "test".
+    """
+    # Make directory for the split if it doesn't exist
+    os.makedirs(os.path.join(output_dir, split), exist_ok=True)
+
+    # Pairing relies on signal_files[i] and annotation_files[i] belonging to
+    # the same recording.
+    for signal, annot in zip(signal_files, annotation_files):
+
+        # Get filename (without the .edf extension) for saving
+        file_name = signal.parts[-1][:-4]
+
+        # Load signal and annotations
+        raw = mne.io.read_raw_edf(signal, preload=True, verbose=False, include=all_channels)
+        ann = mne.read_annotations(annot)
+        raw.set_annotations(ann)
+
+        # Rename channel type to ecg for easier filtering // rename for positions
+        raw.set_channel_types({'ECG': 'ecg'})
+        raw.rename_channels(mapping)
+        # Band-pass per modality, then a 50 Hz power-line notch
+        raw.filter(l_freq=0.1, h_freq=75.0, picks='eeg', verbose='ERROR')
+        raw.filter(l_freq=0.5, h_freq=120.0, picks='ecg', verbose='ERROR')
+        raw.notch_filter(50.0, verbose="ERROR")
+
+        # Get annotations
+        events, event_id = mne.events_from_annotations(raw)
+        event_id = {k: v for k, v in event_id.items() if v not in [1,2]} # event_id markers 1,2 correspond to the light on/off moments
+        # 30 s sleep epochs starting at each annotation
+        epochs = mne.Epochs(raw, events, event_id, tmin=0, tmax=30, baseline=None, verbose=True)
+        data = epochs.get_data()[:, :, :-1] # epochs returns 30*sampling_freq + 1 samples so we discard the last one
+
+        # We need to create labels - it's the third column of events.
+        # Discard markers 1 and 2 (lights on/off), then subtract 3 so the
+        # remaining sleep-stage codes become 0-4.
+        labels = events[:, 2]
+        remove_on_off = labels[(labels != 1) & (labels != 2)]
+        labels_corrected = remove_on_off - 3
+
+        # Write to pickle
+        dump_pickle(X=data, y=labels_corrected, file_name=file_name, split=split)
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME')
+
+ args = parser.parse_args()
+ input_dir = Path(args.input_dir)
+ output_dir = args.output_dir
+
+ # List all the files in the directory
+ files = [f for f in input_dir.rglob('*.edf')]
+ train_files, train_ann, val_files, val_ann, test_files, test_ann = match_annotations_and_files_and_split(input_dir)
+
+ # Make output directory if it doesn't exist
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Process each split
+ process_split(train_files, train_ann, output_dir, split="train")
+ process_split(val_files, val_ann, output_dir, split="val")
+ process_split(test_files, test_ann, output_dir, split="test")
+
+ # Write to HDF5
+ create_hdf5(output_dir, split="train")
+ create_hdf5(output_dir, split="val")
+ create_hdf5(output_dir, split="test")
+
+
diff --git a/make_mimic_iv_dataset.py b/make_mimic_iv_dataset.py
new file mode 100644
index 0000000..506f255
--- /dev/null
+++ b/make_mimic_iv_dataset.py
@@ -0,0 +1,127 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import wfdb
+import numpy as np
+import pandas as pd
+import pickle
+import argparse
+import tqdm
+import ast
+import os
+import wfdb
+import subprocess
+import multiprocessing
+import h5py
+
+from process_raw_ecg import preprocess_signal, time_segmenting
+
+def dump_data(data_slice, output_dir, data_group_name, file_idx, channel_info):
+ output_path = os.path.join(output_dir, f"MIMIC-IV-ECG_12_channels_{file_idx}.h5")
+ with h5py.File(output_path, "a") as h5f:
+ if h5f.attrs.get("channel_name") is None:
+ h5f.attrs['channel_names'] = channel_info
+ grp = h5f.create_group(data_group_name)
+ X_data = np.array(data_slice, dtype=np.float16)
+ grp.create_dataset("X", data=X_data, dtype='float16')
+
+def process_single_file(file_paths, output_dir, downsample_fs, split_signal, file_idx, worker_id, batch_size=300):
+    """
+    Process a batch of MIMIC-IV-ECG records and append them to one HDF5 file.
+
+    Each record is read with wfdb, band-pass filtered and downsampled via
+    preprocess_signal, min-max normalized to [0, 1], optionally segmented in
+    time, and buffered; every `batch_size` records the buffer is flushed to
+    the file named with `file_idx` under a "<worker_id>_<batch>" group.
+
+    Args:
+        file_paths (list[str]): Paths to .hea header files for this worker.
+        output_dir (str): Directory for the output .h5 file.
+        downsample_fs (int): Target sampling frequency.
+        split_signal (int or None): Segment length in seconds; None disables splitting.
+        file_idx (int): Index used to name the output .h5 file.
+        worker_id (int): Worker index (used for group names and the tqdm slot).
+        batch_size (int): Number of records buffered before each flush.
+    """
+    batch_data = []
+    batch_idx = 0
+    channel_info = None  # lead names from the most recently read record
+
+    def flush_batch():
+        # Write the buffered windows as one HDF5 group and reset the buffer.
+        nonlocal batch_idx
+        if not batch_data:
+            return
+        concatenated = np.concatenate(batch_data, axis=0)
+        group_name = f"{worker_id}_{batch_idx}"
+        dump_data(concatenated, output_dir, group_name, file_idx, channel_info)
+        batch_data.clear()
+        batch_idx += 1
+
+    with tqdm.tqdm(total=len(file_paths), desc=f"Worker {worker_id}", position=worker_id) as pbar:
+        for file_path in file_paths:
+            # wfdb expects the record path without the .hea extension.
+            # NOTE(review): split('.')[0] truncates at the FIRST dot — assumes
+            # no other dots appear in the path; confirm against the file layout.
+            record_path = file_path.split('.')[0]
+            record = wfdb.rdrecord(record_path)
+
+            raw_ecg = np.array(record.p_signal).T
+            sampling_rate = record.fs
+            channel_info = record.sig_name
+
+            # Band-pass 0.5-120 Hz and resample to downsample_fs
+            processed_ecg = preprocess_signal(raw_ecg, sampling_rate, 0.5, 120, downsample_fs, None)
+
+            # Skip constant (flat) recordings: min-max scaling would divide by zero
+            p_min = processed_ecg.min()
+            p_max = processed_ecg.max()
+            if (p_max - p_min) == 0:
+                pbar.update(1)
+                continue
+
+            # Global min-max normalization to [0, 1]
+            processed_ecg = (processed_ecg - p_min) / (p_max - p_min)
+            merged_splits = processed_ecg.reshape(1, processed_ecg.shape[0], processed_ecg.shape[1])
+
+            # Optional segmentation into split_signal-second windows
+            if split_signal is not None:
+                signal_splitted = time_segmenting(merged_splits, split_signal, sampling_rate, downsample_fs, None)
+                merged_splits = np.concatenate(signal_splitted, axis=0)
+
+            batch_data.append(merged_splits)
+
+            if len(batch_data) >= batch_size:
+                flush_batch()
+
+            pbar.update(1)
+
+    flush_batch() # flush remaining records that didn't fill a full batch
+
+def loading_and_processing_parallel(files, output_dir, downsampling_fs, split_signal, max_sessions_per_file=30000):
+ num_workers = 27
+ worker_args = []
+
+ print(max_sessions_per_file)
+ for i in range(0, len(files), max_sessions_per_file):
+ worker_args.append((files[i:i+max_sessions_per_file], output_dir, downsampling_fs, split_signal, int(i / max_sessions_per_file)))
+
+ print(f"Processing {len(worker_args)} groups of files:")
+ os.makedirs(output_dir, exist_ok=True)
+
+ with multiprocessing.Manager() as manager:
+ with multiprocessing.Pool(num_workers) as pool:
+ pool.starmap(process_single_file, [
+ (args[0], args[1], args[2], args[3], args[4], worker_id, 300)
+ for worker_id, args in enumerate(worker_args)
+ ])
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process MIMIC-IV-ECG .hea and .dat files and save as .h5 files.")
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory containing MIMIC-IV-ECG files.')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed .h5 files.')
+ parser.add_argument('--downsampling_fs', type=int, default=256, help='Desired downsampling frequency.')
+ parser.add_argument('--split_signal', type=int, default=None, help='Length of segments to split the signals into (in seconds). If None, no splitting is done.')
+
+ args = parser.parse_args()
+
+ if os.path.exists('/scratch2/msc25h9/mimic-iv-ecg-files.pkl'):
+ with open('/scratch2/msc25h9/mimic-iv-ecg-files.pkl', 'rb') as f:
+ files = pickle.load(f)
+ else:
+ result = subprocess.run(["find", args.input_dir, "-maxdepth", str(5), "-type", "f", "-name", "*.hea"], stdout=subprocess.PIPE, text=True)
+ files = result.stdout.splitlines()
+
+ loading_and_processing_parallel(files, args.output_dir, args.downsampling_fs, args.split_signal)
\ No newline at end of file
diff --git a/make_ptbxl_dataset.py b/make_ptbxl_dataset.py
new file mode 100644
index 0000000..b1db63c
--- /dev/null
+++ b/make_ptbxl_dataset.py
@@ -0,0 +1,158 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import pandas as pd
+import numpy as np
+
+import wfdb
+import ast
+import argparse
+import os
+import pickle
+import h5py
+
+from process_raw_ecg import preprocess_signal
+from itertools import chain
+from tqdm import tqdm
+
+def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
+ """
+ Lists all the pickle files in the source directory and writes them in the .h5 file.
+
+ Args:
+ source_dir (str): Source directory where pickle files are found.
+ target_file (str): Path to the target .h5 file.
+ finetune (bool): Whether data has label or not.
+ group_size (int): Group size in HDF5 saving.
+ """
+ # List all the files in the folder
+ files = sorted(os.listdir(source_dir))
+ data_group = []
+
+
+ with h5py.File(target_file, 'w') as h5f:
+ for i, file in enumerate(tqdm(files)):
+ with open(os.path.join(source_dir, file), 'rb') as f:
+
+ sample = pickle.load(f)
+ data_group.append(sample)
+
+ if (i + 1) % group_size == 0 or i == len(files) - 1:
+
+ grp = h5f.create_group(f"data_group_{i // group_size}")
+ X_data = np.array([s['X'] for s in data_group])
+ grp.create_dataset("X", data=X_data)
+
+ if(finetune):
+ y_data = np.array([s['y'] for s in data_group])
+ grp.create_dataset("y", data=y_data)
+
+ data_group = []
+
+def process_csv_files(args, csv_file, setup, split_type):
+ """
+ Process csv files and dump into pickle files. The csv files are taken from MERL ICML 2024 repository (https://github.com/cheliu-computation/MERL-ICML2024/tree/main/finetune/data_split).
+
+ Args:
+ args: Parsed arguments from argparse.
+ csv_file (str): Path to the csv file.
+ setup (str): One of ['super_class', 'sub_class', 'form' and 'rhythm'].
+ split (str): One of ['train', 'val', 'test'].
+ """
+ # Make output directory for this setup if it doesn't exist
+ os.makedirs(os.path.join(args.output_dir, setup, split_type), exist_ok=True)
+
+ # Read the csv file
+ print(f"Processing {csv_file}...")
+ df = pd.read_csv(csv_file)
+
+ for idx, row in tqdm(df.iterrows(), total=df.shape[0]):
+
+ # Load the signal, we always use the 500 Hz version
+ file_path = os.path.join(args.input_dir, row['filename_hr'])
+ signal, _ = wfdb.rdsamp(file_path)
+
+ # Preprocess the signal
+ processed_signal = preprocess_signal(signal.T, args.sampling_rate, 0.5, 120, args.downsampling_fs, None)
+
+ # Get the labels
+ if setup == 'super_class': # superclass setup has 5 classes / last 5 columns of every row are labels
+ y = row.iloc[-5:].values.astype(np.int8)
+ elif setup == 'sub_class': # subclass setup has 23 classes / last 23 columns of every row are labels
+ y = row.iloc[-23:].values.astype(np.int8)
+ elif setup == 'form': # form setup has 19 classes / last 19 columns of every row are labels
+ y = row.iloc[-19:].values.astype(np.int8)
+ elif setup == 'rhythm': # rhythm setup has 12 classes / last 12 columns of every row are labels
+ y = row.iloc[-12:].values.astype(np.int8)
+ else:
+ raise ValueError("Invalid setup provided.")
+
+ data_dict = {"X": processed_signal, "y": y}
+ dump_path = os.path.join(args.output_dir, setup, split_type, f'ptb-xl-{idx}.pkl')
+
+ with open(dump_path, "wb") as f:
+ pickle.dump(data_dict, f)
+
+def main_splitted(args, setup):
+ """
+ Main function for processing PTB_XL dataset and saving as pickle file.
+ Uses splits provided by MERL paper for consistency in comparison.
+ Processing is done with the respect to the the pipeline my code expects.
+ Filtering, downsampling and segmenting in 5s pieces is done here.
+
+ Args:
+ args: Parsed arguments from argparse.
+ setup (str): One of ['super_class', 'sub_class', 'form' and 'rhythm'].
+ """
+ train_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_train.csv')
+ val_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_val.csv')
+ test_csv = os.path.join(args.csv_files_dir, setup, f'ptbxl_{setup}_test.csv')
+
+ # Create output directory for this setup if it doesn't exist
+ os.makedirs(os.path.join(args.output_dir, setup), exist_ok=True)
+
+ # Write each csv into pickle files with X and y
+ process_csv_files(args, train_csv, setup, split_type='train')
+ process_csv_files(args, val_csv, setup, split_type='val')
+ process_csv_files(args, test_csv, setup, split_type='test')
+
+ # Finally, write to HDF5
+ to_do = ['train', 'val', 'test']
+ for td in to_do:
+ if os.path.exists(args.output_dir + '/' + setup + "/" + td + "/" + ".h5"):
+ print(f"File {td}.h5 already exists!")
+ else:
+ print(f"Creating file {td}.h5.")
+ create_hdf5(args.output_dir + "/" + setup + "/" + td, args.output_dir + "/" + setup + "/" + td + ".h5")
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process PTB-XL dataset files and save as pickle.")
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--csv_files_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--sampling_rate', type=int, default=500)
+ parser.add_argument('--downsampling_fs', type=int, default=256)
+ parser.add_argument('--setup', type=str, default=5, help="Either super_class, sub_class, form or rhythm.")
+
+ args = parser.parse_args()
+ setup = args.setup
+
+ # Run main function for
+ main_splitted(args, setup=setup)
\ No newline at end of file
diff --git a/make_pulsedb_dataset.py b/make_pulsedb_dataset.py
new file mode 100644
index 0000000..9737728
--- /dev/null
+++ b/make_pulsedb_dataset.py
@@ -0,0 +1,152 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import wfdb
+import numpy as np
+import pandas as pd
+import scipy.io
+import pickle
+import argparse
+import multiprocessing
+import tqdm
+import h5py
+import ast
+import os
+
+from process_raw_ecg import preprocess_signal, time_segmenting
+from pathlib import Path
+from sklearn.model_selection import train_test_split
+from mat73 import loadmat
+
+def dump_data(data_slice, output_dir, data_group_name, file_idx, channel_info, dataset):
+
+ data_group = data_slice
+ output_path = os.path.join(output_dir, f"{dataset}_2_channels_{file_idx}.h5")
+
+ with h5py.File(output_path, "a") as h5f:
+
+ if h5f.attrs.get("channel_names") is None:
+ h5f.attrs["channel_names"] = channel_info
+
+ grp = h5f.create_group(data_group_name)
+ X_data = np.array(data_group, dtype=np.float16)
+ grp.create_dataset("X", data=X_data, dtype='float16')
+
+ print(f"Saved {len(data_slice)} samples to {output_path}.")
+
+def process_single_file(file_paths, output_dir, sampling_rate, upsample_fs, split_signal, dataset, file_idx, worker_id):
+ """
+ Process a single ,mat file from VitalDB dataset. Friendly for multiprocessing.
+ """
+
+ with tqdm.tqdm(total=len(file_paths), desc=f"Worker {worker_id}", position=worker_id) as pbar:
+ for file_path in file_paths:
+ channel_info = ['II', 'PPG']
+
+ # Handle case if file is corrupted
+ try:
+ session_name = file_path.parts[-1][:-4]
+ data = loadmat(file_path)
+ raw_ecg = np.array(data['Subj_Wins']['ECG_Raw'])
+ raw_ppg = np.array(data['Subj_Wins']['PPG_Raw'])
+
+ except Exception as e:
+ print(f"Standard loading failed for {session_name}. Falling back to scipy.io.loadmat().")
+ try:
+ session_name = file_path.parts[-1][:-4]
+ data = scipy.io.loadmat(file_path)
+ unpacked = data['Subj_Wins'][0, 0]
+ raw_ecg = np.array(unpacked['ECG_Raw']).T
+ raw_ppg = np.array(unpacked['PPG_Raw']).T
+
+ except Exception as scipy_e:
+ print(f"Error loading {file_path} even with scipy.io. {scipy_e}.")
+ return None
+
+ # Preprocess signals and concatenate
+ processed_ecg = preprocess_signal(raw_ecg, sampling_rate, 0.5, 120, None, upsample_fs)
+ processed_ppg = preprocess_signal(raw_ppg, sampling_rate, 0.5, 12, None, upsample_fs)
+
+ if processed_ecg.ndim < 3 or processed_ppg.ndim < 3:
+ processed_ecg = processed_ecg.reshape(1, 1, -1)
+ processed_ppg = processed_ppg.reshape(1, 1, -1)
+
+ signals = np.concatenate((processed_ecg, processed_ppg), axis=1)
+
+ merged_splits = signals
+
+ # Time splitting
+ if split_signal is not None:
+
+ signal_splitted = time_segmenting(signals, split_signal, sampling_rate, None, upsample_fs)
+ merged_splits = np.concatenate(signal_splitted, axis=0)
+
+ dump_data(merged_splits, output_dir, session_name, file_idx, channel_info, dataset)
+
+ pbar.update(1)
+
+def loading_and_processing_parallel(files, output_dir, sampling_rate, upsampling_fs, split_signal, dataset, max_sessions_per_file=300):
+ """
+ Load and process all .mat files in the input directory in parallel.
+
+ Args:
+ files (list): List of .mat filenames to process.
+ output_dir (str): Directory to save processed pickle files.
+ sampling_rate (int): Original sampling rate of the signals.
+ upsampling_fs (int): Desired upsampling frequency.
+ split_signal (int or None): Length of segments to split the signals into (in seconds). If None, no splitting is done.
+ dataset (str): Dataset name.
+ """
+ num_workers = 4
+ worker_args = []
+
+ # Prepare arguments for parallel processing
+ for i in range(0, len(files), max_sessions_per_file):
+ worker_args.append((files[i:i+max_sessions_per_file], output_dir, sampling_rate, upsampling_fs, split_signal, dataset, int(i / max_sessions_per_file)))
+
+ print(f"Processing {len(worker_args)} groups of files:")
+ os.makedirs(output_dir, exist_ok=True)
+
+ # Use multiprocessing to parallelize the process
+ with multiprocessing.Manager() as manager:
+ with multiprocessing.Pool(num_workers) as pool:
+ pool.starmap(process_single_file, [(args[0], args[1], args[2], args[3], args[4], args[5], args[6], worker_id) for worker_id, args in enumerate(worker_args)])
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser(description="Process VitalDB .mat files and save as pickle files.")
+ parser.add_argument('--dataset', type=str, default='VitalDB', help='Either VitalDB or MimicDB.')
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Directory containing VitalDB .mat files.')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME', help='Directory to save processed pickle files.')
+ parser.add_argument('--sampling_rate', type=int, default=125, help='Original sampling rate of the signals.')
+ parser.add_argument('--upsampling_fs', type=int, default=256, help='Desired upsampling frequency.')
+ parser.add_argument('--split_signal', type=int, default=5, help='Length of segments to split the signals into (in seconds). If None, no splitting is done.')
+
+ args = parser.parse_args()
+ dataset = args.dataset
+ input_dir = Path(args.input_dir)
+ output_dir = Path(args.output_dir)
+ sampling_rate = args.sampling_rate
+ upsampling_fs = args.upsampling_fs
+ split_signal = args.split_signal
+
+ # List all .mat files in the input directory
+ files = [f for f in input_dir.rglob('*.mat')]
+ loading_and_processing_parallel(files, output_dir, sampling_rate, upsampling_fs, split_signal, dataset)
+
diff --git a/make_seed_vii_dataset.py b/make_seed_vii_dataset.py
new file mode 100644
index 0000000..6811975
--- /dev/null
+++ b/make_seed_vii_dataset.py
@@ -0,0 +1,202 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import mne
+import argparse
+import os
+import numpy as np
+import csv
+import datetime
+import pickle
+import h5py
+from tqdm import tqdm
+# Code taken from SEED-VII instructions
+# Target sampling frequency: recordings are resampled to this rate before segmenting.
+sampling_freq=256
+
+# We decide to label
+# neutral = 0
+# happy = 1
+# sad = 2
+# disgust = 3
+# anger = 4
+# fear = 5
+# surprise = 6
+
+# Emotion label of each of the 20 trials, per session. Each session's 20
+# trials repeat the same 10-trial label block twice.
+per_session_labels = {
+    '1': [1, 0, 3, 2, 4, 4, 2, 3, 0, 1, 1, 0, 3, 2, 4, 4, 2, 3, 0, 1],
+    '2': [4, 2, 5, 0, 6, 6, 0, 5, 2, 4, 4, 2, 5, 0, 6, 6, 0, 5, 2, 4],
+    '3': [1, 6, 3, 5, 4, 4, 5, 3, 6, 1, 1, 6, 3, 5, 4, 4, 5, 3, 6, 1],
+    '4': [3, 2, 5, 6, 1, 1, 6, 5, 2, 3, 3, 2, 5, 6, 1, 1, 6, 5, 2, 3]
+}
+
+# Channel order of the saved data: 62 EEG electrodes followed by one ECG channel.
+channels = ['FP1', 'FPZ', 'FP2', 'AF3', 'AF4', 'F7', 'F5', 'F3', 'F1', 'FZ', 'F2', 'F4', 'F6', 'F8', 'FT7', 'FC5', 'FC3', 'FC1', 'FCZ', 'FC2', 'FC4', 'FC6', 'FT8', 'T7', 'C5', 'C3', 'C1', 'CZ', 'C2', 'C4', 'C6', 'T8', 'TP7', 'CP5', 'CP3', 'CP1', 'CPZ', 'CP2', 'CP4', 'CP6', 'TP8', 'P7', 'P5', 'P3', 'P1', 'PZ', 'P2', 'P4', 'P6', 'P8', 'PO7', 'PO5', 'PO3', 'POZ', 'PO4', 'PO6', 'PO8', 'CB1', 'O1', 'OZ', 'O2', 'CB2', 'ECG']
+
+def process_and_split(input_dir, output_dir, file_dict):
+    """
+    Process sessions and place each 1 s segment in the correct split (train/val/test).
+
+    Follows the splits from PhysioOmni: 20 subjects - 4 sessions each - 20 trials per session.
+    First 10 trials from each session go to the train split.
+    Next 5 trials from each session go to the val split and the last 5 to the test split.
+
+    Args:
+        input_dir (str): Directory containing the raw .cnt recordings.
+        output_dir (str): Root directory with train/val/test subdirectories.
+        file_dict (dict): Mapping subject key -> list of that subject's .cnt file names.
+    """
+    collect_num = 0  # running count of dumped 1 s segments (for logging)
+    for key, value in file_dict.items():
+        for raw_file in value:
+            print("Starting with path:", raw_file)
+            # Prepare the path to the file
+            raw_path = os.path.join(input_dir, raw_file)
+            raw = mne.io.read_raw_cnt(raw_path, ecg=['ECG'])
+
+            # Drop mastoid reference and ocular channels not used downstream
+            raw.drop_channels(['M1', 'M2', 'HEO', 'VEO'])
+            raw.load_data()
+
+            # Preprocessing: band-pass EEG and ECG separately, then 50 Hz notch
+            raw.filter(l_freq=0.1, h_freq=75.0, picks='eeg', verbose='ERROR')
+            raw.filter(l_freq=0.5, h_freq=120.0, picks='ecg', verbose='ERROR')
+            raw.notch_filter(50.0, verbose='ERROR')
+
+            # Resample to the common target frequency if needed
+            if raw.info['sfreq'] != sampling_freq:
+                raw.resample(sampling_freq)
+
+            # Get triggers for each trial
+            trigger, _ = mne.events_from_annotations(raw)
+            data, times = raw.get_data(return_times=True)
+
+            # Trigger sample indices; trials are delimited by consecutive trigger pairs
+            t = trigger[:, 0]
+
+            # We need to separately handle files the authors specify are not accurately
+            # handled with triggers: trial boundaries are reconstructed from wall-clock
+            # timestamps in an external csv.
+            # NOTE(review): the open() paths below are '#CHANGEME' placeholders — fill in before running.
+            if raw_file == "14_20221015_1.cnt":
+                t = []
+                start = datetime.datetime.strptime('14:25:34', '%H:%M:%S')
+                with open('#CHANGEME') as f:
+                    trigger = csv.reader(f)
+                    for row in trigger:
+                        end = datetime.datetime.strptime(row[1].split(' ')[-1], '%H:%M:%S.%f')
+                        time_diff = end.timestamp() - start.timestamp()
+                        t.append(int(round(time_diff * sampling_freq)))
+            elif raw_file == "9_20221111_3.cnt":
+                t = []
+                start = datetime.datetime.strptime('14:01:27', '%H:%M:%S')
+                with open(os.path.join('#CHANGEME')) as f:
+                    trigger = csv.reader(f)
+                    for row in trigger:
+                        end = datetime.datetime.strptime(row[1].split(' ')[-1], '%H:%M:%S.%f')
+                        time_diff = end.timestamp() - start.timestamp()
+                        t.append(int(round(time_diff * sampling_freq)))
+
+            # Session number is the digit right before the '.cnt' extension
+            session_idx = int(raw_file[-5])
+            for i in range(20):
+                # Get the data between two triggers (trial i spans triggers 2i..2i+1)
+                preprocessed_clip = data[:, t[2 * i]:t[2 * i + 1]]
+                num = preprocessed_clip.shape[1] // sampling_freq
+                collect_num += num
+
+                # We want to fetch an int number of seconds and reshape such that each second is a segment
+                signal = preprocessed_clip[:, :num * sampling_freq]
+                split_signal = signal.reshape(signal.shape[0], num, sampling_freq).transpose(1, 0, 2)
+
+                for idx in range(num):
+                    data_dict = {"X": split_signal[idx, :, :], 'y': per_session_labels[str(session_idx)][i]}
+                    dump_name = f"{key}_seed_{session_idx}_{i}_{idx}.pkl"
+
+                    # first 10 trials from session go to train split
+                    if i < 10:
+                        dump_path = os.path.join(output_dir, "train", dump_name)
+
+                    # trials 10-14 go to the val split
+                    elif i < 15 and i >= 10:
+                        dump_path = os.path.join(output_dir, "val", dump_name)
+
+                    # trials 15-19 go to the test split
+                    else:
+                        dump_path = os.path.join(output_dir, "test", dump_name)
+
+                    with open(dump_path, "wb") as f:
+                        pickle.dump(data_dict, f)
+            print("Current sample count:", collect_num)
+
+def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
+ """
+ Lists all the pickle files in the source directory and writes them in the .h5 file.
+
+ Args:
+ source_dir (str): Source directory where pickle files are found.
+ target_file (str): Path to the target .h5 file.
+ finetune (bool): Whether data has label or not.
+ group_size (int): Group size in HDF5 saving.
+ """
+ # List all the files in the folder
+ files = sorted(os.listdir(source_dir))
+ data_group = []
+
+
+ with h5py.File(target_file, 'w') as h5f:
+ for i, file in enumerate(tqdm(files)):
+ with open(os.path.join(source_dir, file), 'rb') as f:
+
+ sample = pickle.load(f)
+ data_group.append(sample)
+
+ if (i + 1) % group_size == 0 or i == len(files) - 1:
+
+ grp = h5f.create_group(f"data_group_{i // group_size}")
+ X_data = np.array([s['X'] for s in data_group])
+ grp.create_dataset("X", data=X_data)
+
+ if(finetune):
+ y_data = np.array([s['y'] for s in data_group])
+ print(y_data)
+ grp.create_dataset("y", data=y_data)
+
+ data_group = []
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME')
+
+ # Parse arguments
+ args = parser.parse_args()
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+
+ # List all files in the directory and prepare per subjects
+ raw_files = os.listdir(input_dir)
+ file_dict = {}
+
+ for file in raw_files:
+ if file[:-15] not in file_dict.keys():
+ file_dict[file[:-15]] = [file]
+ else:
+ file_dict[file[:-15]].append(file)
+
+ os.makedirs(os.path.join(output_dir, "train"), exist_ok=True)
+ os.makedirs(os.path.join(output_dir, "val"), exist_ok=True)
+ os.makedirs(os.path.join(output_dir, "test"), exist_ok=True)
+
+ process_and_split(input_dir, output_dir, file_dict)
+
+ # Finally, write to HDF5
+ to_do = ['train', 'val', 'test']
+ for td in to_do:
+ if os.path.exists(output_dir + '/' + td + '.h5'):
+ print(f"File {td}.h5 already exists!")
+ else:
+ print(f"Creating file {td}.h5.")
+ create_hdf5(output_dir + "/" + td, output_dir + "/" + td + ".h5")
+
\ No newline at end of file
diff --git a/make_siena_dataset.py b/make_siena_dataset.py
new file mode 100644
index 0000000..5a3dcc1
--- /dev/null
+++ b/make_siena_dataset.py
@@ -0,0 +1,124 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import mne
+import argparse
+import tqdm
+import h5py
+import os
+
+import numpy as np
+
+from pathlib import Path
+
+standard = {
+ 'EEG FP1', 'EEG F3', 'EEG C3', 'EEG P3', 'EEG O1', 'EEG F7', 'EEG T3', 'EEG T5', 'EEG FC1', 'EEG FC5',
+ 'EEG CP1', 'EEG CP5', 'EEG F9', 'EEG FZ', 'EEG CZ', 'EEG PZ', 'EEG FP2', 'EEG F4', 'EEG C4', 'EEG P4', 'EEG O2',
+ 'EEG F8', 'EEG T4', 'EEG T6', 'EEG FC2', 'EEG FC6', 'EEG CP2', 'EEG CP6', 'EEG F10'
+}
+
+all_channels = list(standard)
+sampling_freq = 256 # target sampling frequency
+segment_length = 5 # segment length in seconds
+
+def create_hdf5(sliced_data, output_dir, session_name):
+ """
+ Create HDF5 file from sliced data.
+
+ Args:
+ sliced_data (np.ndarray): Sliced data of shape (n_channels, n_intervals, interval_size).
+ output_dir (str or Path): Directory to save the HDF5 file.
+ session_name (str): Name of the session/file.
+ """
+ os.makedirs(output_dir, exist_ok=True)
+ target_path = os.path.join(output_dir, f"SIENA_29_channels.h5")
+
+ with h5py.File(target_path, "a") as hdf5_file:
+ group = hdf5_file.create_group(session_name)
+ X_data = np.array(sliced_data, dtype=np.float16)
+ group.create_dataset("X", data=X_data, dtype='float16')
+
+def process_and_save_files_to_hdf5(file_paths, output_dir):
+ """
+ Process EDF files, extract standard EEG channels, filter, downsample, segment and save to HDF5.
+
+ Args:
+ file_paths (list of Path): List of paths to EDF files.
+ output_dir (str or Path): Directory to save processed HDF5 files.
+ """
+
+ for file_path in tqdm.tqdm(file_paths):
+
+ print(f"Processing file: {file_path.name}")
+ session_name = file_path.parts[-1][:-4]
+ # Load raw EDF file
+ raw = mne.io.read_raw_edf(file_path, preload=True, verbose=False)
+
+ # First map all channel names to upper-case to have consistency
+ upper_mapping = {ch: ch.upper() for ch in raw.ch_names}
+ raw.rename_channels(upper_mapping)
+ raw.pick(all_channels)
+
+ # Strip EEG part from mapping
+ mapping = {ch: ch.split(' ')[1] for ch in raw.ch_names}
+ raw.rename_channels(mapping)
+
+ # Filtering (bandpass and notch)
+ raw.filter(l_freq=0.1, h_freq=75.0, verbose="ERROR")
+ raw.notch_filter(50, verbose="ERROR")
+
+ # Downsample
+ if raw.info['sfreq'] != sampling_freq:
+ raw.resample(sampling_freq)
+ data = raw.get_data(units='uV')
+
+ # Check for NaN and Infs (just for notification)
+ if np.isnan(data).any() or np.isinf(data).any():
+ print(f"Warning: NaN or Inf values found in file {file_path.name}")
+
+ n_channels, n_times = data.shape
+ print(f"Data shape (channels x timepoints): {data.shape}")
+ interval_size = segment_length * sampling_freq
+ num_intervals = n_times // interval_size
+
+ new_sliced_data = data[:, :num_intervals * interval_size].reshape(num_intervals, n_channels, interval_size)
+
+ create_hdf5(new_sliced_data, output_dir, session_name)
+
+
+if __name__ == "__main__":
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Input directory containing raw EDF files')
+ parser.add_argument('--output_dir', type=str, default='#CHANGEME')
+
+ # Parse arguments
+ args = parser.parse_args()
+ input_dir = args.input_dir
+ output_dir = args.output_dir
+
+ # Find all EDF files in the input directory
+ input_dir = Path(input_dir)
+ file_paths = [f for f in input_dir.rglob('*.edf')]
+ print(f"Found {len(file_paths)} EDF files.")
+
+ process_and_save_files_to_hdf5(file_paths, output_dir)
+
+
+
diff --git a/make_wesad_dataset.py b/make_wesad_dataset.py
new file mode 100644
index 0000000..85a1be0
--- /dev/null
+++ b/make_wesad_dataset.py
@@ -0,0 +1,274 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+
+import pickle
+import argparse
+import h5py
+import os
+
+import numpy as np
+
+from process_raw_ecg import preprocess_signal
+from tqdm import tqdm
+
def binary_classification(ecg_process, ppg_process, labels, subject, split_type):
    """
    Segment one subject's recording into 1-minute windows for binary stress
    classification and dump each window to its own pickle file.

    Label IDs:
    - 0 = not defined/transient
    - 1 = baseline
    - 2 = stress
    - 3 = amusement
    - 4 = meditation
    - 5/6/7 = should be ignored in this dataset

    Binary classification: Combined baseline and amusement sessions form non-stress class, while the stress class is original stress session.

    Args:
        ecg_process: Preprocessed ECG array of shape (channels, samples) at 256 Hz.
        ppg_process: Preprocessed PPG array of shape (channels, samples) at 256 Hz.
        labels: Per-sample condition labels at the original 700 Hz rate.
        subject: Subject identifier, used in output file names.
        split_type: Split sub-directory name (e.g. 'train_binary').

    NOTE(review): relies on the module-level global ``output_dir`` set in the
    __main__ block — confirm it is defined before calling this directly.
    """
    # Each segment is 1 min long: labels are sampled at 700 Hz, data at 256 Hz.
    subseq_size_label = 700 * 60
    subseq_size_data = 256 * 60

    # Extract unique labels and the index of their first occurrence.
    uniques, uniques_index = np.unique(labels, return_index=True)
    counter_subject = 0
    for unique, unique_startidx in zip(uniques, uniques_index):
        if unique not in [1, 2, 3]:
            continue

        while True:
            # Map the label-domain start index into the data domain; stop if
            # the run starts past the end of the signal.
            if unique_startidx // subseq_size_label * subseq_size_data > ppg_process.shape[1]:
                break
            # Find where this contiguous run of `unique` labels ends.
            for next_idx in range(unique_startidx, len(labels)):
                if unique != labels[next_idx]:
                    break

            totalsubseqs = (next_idx - unique_startidx) // subseq_size_label
            startidx = unique_startidx // subseq_size_label * subseq_size_data
            data_temp_60sec_ppg = ppg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]
            data_temp_60sec_ecg = ecg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]

            # If the run extends past the end of the data, recompute how many
            # whole windows actually fit.  BUGFIX: the window count must come
            # from the time axis (shape[1], not shape[0]) and the trim must
            # slice the time axis — the original sliced the channel axis,
            # dropping channels instead of excess samples.
            if unique_startidx // subseq_size_label * subseq_size_data + totalsubseqs * subseq_size_data > ppg_process.shape[1]:
                totalsubseqs = data_temp_60sec_ppg.shape[1] // subseq_size_data
                data_temp_60sec_ppg = data_temp_60sec_ppg[:, :totalsubseqs * subseq_size_data]
                data_temp_60sec_ecg = data_temp_60sec_ecg[:, :totalsubseqs * subseq_size_data]

            if totalsubseqs == 0:
                break

            # Split the run into windows: (windows, channels, samples).
            data_temp_60sec_ppg = np.stack(np.split(data_temp_60sec_ppg, totalsubseqs, 1), 0)
            data_temp_60sec_ecg = np.stack(np.split(data_temp_60sec_ecg, totalsubseqs, 1), 0)

            # Concat both signals along the channel axis.
            total_concat = np.concatenate([data_temp_60sec_ecg, data_temp_60sec_ppg], axis=1)
            # Amusement (3) is merged into the non-stress class (label 1).
            unique_temp = 1 if unique == 3 else unique
            label_temp_60sec = np.repeat(unique_temp, totalsubseqs)

            print(f"{subject}:", total_concat.shape, label_temp_60sec)

            # Write each sample piece to pickle
            for idx in range(len(label_temp_60sec)):

                data_dict = {"X": total_concat[idx], "y": label_temp_60sec[idx]}
                dump_path = os.path.join(output_dir, split_type, f"wesad_{subject}_{counter_subject}.pkl")
                counter_subject += 1

                with open(dump_path, 'wb') as f:
                    pickle.dump(data_dict, f)

            # Labels 1/2/3 occur in a single contiguous run each, so one pass
            # suffices; the meditation continuation only exists in the
            # multiclass variant.
            if unique != 4:
                break
+
def multiclass_classification(ecg_process, ppg_process, labels, subject, split_type):
    """
    Multiclass classification: Stress, Baseline, Amusement and Meditation.

    Segments one subject's recording into 1-minute windows and dumps each
    window to its own pickle file.  Meditation (label 4) can occur in two
    separate runs, which the tail of the while-loop handles.

    Args:
        ecg_process: Preprocessed ECG array of shape (channels, samples) at 256 Hz.
        ppg_process: Preprocessed PPG array of shape (channels, samples) at 256 Hz.
        labels: Per-sample condition labels at the original 700 Hz rate.
        subject: Subject identifier, used in output file names.
        split_type: Split sub-directory name (e.g. 'train_multiclass').

    NOTE(review): relies on the module-level global ``output_dir`` set in the
    __main__ block — confirm it is defined before calling this directly.
    """
    # Each segment is 1 min long: labels are sampled at 700 Hz, data at 256 Hz.
    subseq_size_label = 700 * 60
    subseq_size_data = 256 * 60

    # Extract unique labels and the index of their first occurrence.
    uniques, uniques_index = np.unique(labels, return_index=True)
    counter_subject = 0
    for unique, unique_startidx in zip(uniques, uniques_index):

        # `flag` guards against scanning for a second meditation run twice.
        flag = False
        if unique not in [1, 2, 3, 4]:
            continue

        while True:
            # Map the label-domain start index into the data domain; stop if
            # the run starts past the end of the signal.
            if unique_startidx // subseq_size_label * subseq_size_data > ppg_process.shape[1]:
                break
            # Find where this contiguous run of `unique` labels ends.
            for next_idx in range(unique_startidx, len(labels)):
                if unique != labels[next_idx]:
                    break

            totalsubseqs = (next_idx - unique_startidx) // subseq_size_label
            startidx = unique_startidx // subseq_size_label * subseq_size_data
            data_temp_60sec_ppg = ppg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]
            data_temp_60sec_ecg = ecg_process[:, startidx:startidx + totalsubseqs * subseq_size_data]

            # If the run extends past the end of the data, recompute how many
            # whole windows actually fit.  BUGFIX: the window count must come
            # from the time axis (shape[1], not shape[0]) and the trim must
            # slice the time axis — the original sliced the channel axis,
            # dropping channels instead of excess samples.
            if unique_startidx // subseq_size_label * subseq_size_data + totalsubseqs * subseq_size_data > ppg_process.shape[1]:
                totalsubseqs = data_temp_60sec_ppg.shape[1] // subseq_size_data
                data_temp_60sec_ppg = data_temp_60sec_ppg[:, :totalsubseqs * subseq_size_data]
                data_temp_60sec_ecg = data_temp_60sec_ecg[:, :totalsubseqs * subseq_size_data]

            if totalsubseqs == 0:
                break

            # Split the run into windows: (windows, channels, samples).
            data_temp_60sec_ppg = np.stack(np.split(data_temp_60sec_ppg, totalsubseqs, 1), 0)
            data_temp_60sec_ecg = np.stack(np.split(data_temp_60sec_ecg, totalsubseqs, 1), 0)

            # Concat both signals along the channel axis.
            total_concat = np.concatenate([data_temp_60sec_ecg, data_temp_60sec_ppg], axis=1)
            label_temp_60sec = np.repeat(unique, totalsubseqs)

            print(f"{subject}:", total_concat.shape, label_temp_60sec)
            # Write each sample piece to pickle
            for idx in range(len(label_temp_60sec)):

                data_dict = {"X": total_concat[idx], "y": label_temp_60sec[idx]}
                dump_path = os.path.join(output_dir, split_type, f"wesad_{subject}_{counter_subject}.pkl")
                counter_subject += 1

                with open(dump_path, 'wb') as f:
                    pickle.dump(data_dict, f)

            if unique != 4:
                break
            else:
                # Meditation can appear in a second run later in the session:
                # search the remaining labels for another label-4 run once.
                if flag:
                    break
                flag = True
                new_label = labels[next_idx:]
                uniques_temp, uniques_indedata_temp = np.unique(new_label, return_index=True)
                try:
                    unique_startidx = uniques_indedata_temp[np.where(uniques_temp == 4)][0] + next_idx
                except IndexError:
                    break
+
def process_split(input_dir, output_dir, subjects, split_type, classification_type):
    """
    Process PPG and ECG data for a list of subjects, segmenting each
    recording according to the requested classification task.
    """
    # Make sure the split's output sub-directory exists.
    os.makedirs(os.path.join(output_dir, split_type), exist_ok=True)

    # Select the segmentation routine once, outside the subject loop.
    segmenter = binary_classification if classification_type == "binary" else multiclass_classification

    for subject in subjects:
        subject_file = os.path.join(input_dir, subject, f"{subject}.pkl")
        with open(subject_file, 'rb') as f:
            data = pickle.load(f, encoding='latin1')

        # Chest ECG and wrist PPG, transposed to (channels, samples); plus labels.
        ecg = data['signal']['chest']['ECG'].T
        ppg = data['signal']['wrist']['BVP'].T
        labels = data['label']

        # Band-pass filter and resample both modalities to a common 256 Hz.
        ecg_process = preprocess_signal(ecg, fs=700, low=0.5, high=120.0, downsample_fs=256, upsample_fs=None)
        ppg_process = preprocess_signal(ppg, fs=64, low=0.5, high=8.0, downsample_fs=None, upsample_fs=256)

        segmenter(ecg_process, ppg_process, labels, subject, split_type)
+
def create_hdf5(source_dir, target_file, finetune=True, group_size=1000):
    """
    Lists all the pickle files in the source directory and writes them in the .h5 file.

    Samples are buffered and flushed to HDF5 in groups of ``group_size``
    (the final group may be smaller).

    Args:
        source_dir (str): Source directory where pickle files are found.
        target_file (str): Path to the target .h5 file.
        finetune (bool): Whether data has label or not.
        group_size (int): Group size in HDF5 saving.
    """
    # Sort so that group membership is deterministic across runs.
    files = sorted(os.listdir(source_dir))
    data_group = []

    with h5py.File(target_file, 'w') as h5f:
        for i, file in enumerate(tqdm(files)):
            with open(os.path.join(source_dir, file), 'rb') as f:
                data_group.append(pickle.load(f))

            # Flush a full group — or the final partial group — to disk.
            if (i + 1) % group_size == 0 or i == len(files) - 1:
                grp = h5f.create_group(f"data_group_{i // group_size}")
                X_data = np.array([s['X'] for s in data_group])
                grp.create_dataset("X", data=X_data)

                # Labels only exist for finetuning data.
                # (Removed stray debug `print(y_data)`; fixed `if(finetune)` idiom.)
                if finetune:
                    y_data = np.array([s['y'] for s in data_group])
                    grp.create_dataset("y", data=y_data)

                data_group = []
+
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--input_dir', type=str, default='#CHANGEME', help='Input directory containing raw files.')
    parser.add_argument('--output_dir', type=str, default='#CHANGEME')

    # Parse arguments
    args = parser.parse_args()
    input_dir = args.input_dir
    output_dir = args.output_dir

    # Keep only subject folders (named "S2", "S3", ...).  `startswith` is
    # safe on empty names, where the original `name[0]` would raise.
    subjects = sorted(name for name in os.listdir(input_dir) if name.startswith("S"))

    # Create output directory if it doesn't exist
    os.makedirs(output_dir, exist_ok=True)

    # Taking pre-determined train/val/test splits from Pulse-PPG (binary classification)
    process_split(input_dir, output_dir, subjects[:11], split_type='train_binary', classification_type='binary')
    process_split(input_dir, output_dir, subjects[11:13], split_type='val_binary', classification_type='binary')
    process_split(input_dir, output_dir, subjects[13:15], split_type='test_binary', classification_type='binary')

    process_split(input_dir, output_dir, subjects[:11], split_type='train_multiclass', classification_type='multiclass')
    process_split(input_dir, output_dir, subjects[11:13], split_type='val_multiclass', classification_type='multiclass')
    process_split(input_dir, output_dir, subjects[13:15], split_type='test_multiclass', classification_type='multiclass')

    # Finally, write each split to HDF5 (skipping splits already converted).
    to_do = ['train_binary', 'val_binary', 'test_binary', 'train_multiclass', 'val_multiclass', 'test_multiclass']
    for td in to_do:
        # Portable path construction instead of '/'-string concatenation.
        h5_path = os.path.join(output_dir, f"{td}.h5")
        if os.path.exists(h5_path):
            print(f"File {td}.h5 already exists!")
        else:
            print(f"Creating file {td}.h5.")
            create_hdf5(os.path.join(output_dir, td), h5_path)
+
+
+
\ No newline at end of file
diff --git a/multiloader_data_module_PanLUNA.yaml b/multiloader_data_module_PanLUNA.yaml
new file mode 100644
index 0000000..2595770
--- /dev/null
+++ b/multiloader_data_module_PanLUNA.yaml
@@ -0,0 +1,720 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+defaults:
+ - dataset_types
+
+data_module:
+ _target_: 'data_module.multiloader_data_module.VaryingChannelsDataModule'
+ cfg:
+ num_workers: ${num_workers}
+ batch_size: ${batch_size}
+ train_val_split_ratio: 0.8
+ datasets:
+ TUEG_20_channels_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 20
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_20_channels_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 20
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_20_channels_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 20
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_11:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_12:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_13:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_14:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_16:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_18:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_19:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_20:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_23:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_24:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_31:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_4:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_5:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_15:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_21:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_22:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_25:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_26:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_27:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_28:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_20_channels_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 20
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_29:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_30:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_6:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_7:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_8:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_9:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+    TUEG_22_channels_10:
+      _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+      hdf5_file: "#CHANGEME"
+      num_channels: 22
+      channels: ${dataset_types.tueg.channels}
+      location_fn: ${dataset_types.tueg.location_fn}
+      sensor_type: ${dataset_types.tueg.sensor_type}
+ TUEG_22_channels_17:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+      hdf5_file: "#CHANGEME"
+ num_channels: 22
+ channels: ${dataset_types.tueg.channels}
+ location_fn: ${dataset_types.tueg.location_fn}
+ sensor_type: ${dataset_types.tueg.sensor_type}
+ VitalDB_2_channels_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_4:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_5:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_6:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_7:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_8:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ VitalDB_2_channels_9:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_4:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_5:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MimicDB_2_channels_6:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 2
+ channels: ${dataset_types.pulsedb.channels}
+ location_fn: ${dataset_types.pulsedb.location_fn}
+ sensor_type: ${dataset_types.pulsedb.sensor_type}
+ MIMIC-IV-ECG_12_channels_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_4:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_5:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_6:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_7:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_8:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_9:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_10:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_11:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_12:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_13:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_14:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_15:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_16:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_17:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_18:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_19:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_20:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_21:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_22:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_23:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_24:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_25:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ MIMIC-IV-ECG_12_channels_26:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.mimicIV.channels}
+ location_fn: ${dataset_types.mimicIV.location_fn}
+ sensor_type: ${dataset_types.mimicIV.sensor_type}
+ CODE15_0:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_1:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_2:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_3:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_4:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_5:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_6:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_7:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_8:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_9:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_10:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_11:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_12:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_13:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_14:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_15:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_16:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ CODE15_17:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 12
+ channels: ${dataset_types.code15.channels}
+ location_fn: ${dataset_types.code15.location_fn}
+ sensor_type: ${dataset_types.code15.sensor_type}
+ SIENA_29_channels:
+ _target_: 'datasets.pretraining_datasets_PanLUNA.BioSignal_Dataset'
+ hdf5_file: "#CHANGEME"
+ num_channels: 29
+ channels: ${dataset_types.siena.channels}
+ location_fn: ${dataset_types.siena.location_fn}
+ sensor_type: ${dataset_types.siena.sensor_type}
\ No newline at end of file
diff --git a/pretrain_task_PanLUNA.py b/pretrain_task_PanLUNA.py
new file mode 100644
index 0000000..5d2944e
--- /dev/null
+++ b/pretrain_task_PanLUNA.py
@@ -0,0 +1,273 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+
+import torch
+import pytorch_lightning as pl
+import hydra
+import wandb
+import torch_optimizer as torch_optim
+import torch.nn.functional as F
+import matplotlib.pyplot as plt
+import numpy as np
+from criterion.query_specialization_criterion import QuerySpecializationCriterion
+
class ChannelWiseNormalize:
    """Normalize each channel of a (B, C, T) tensor to zero mean, unit std.

    Statistics are computed over the time dimension independently for every
    channel of every sample; ``eps`` guards against division by zero on
    constant channels. The whole operation runs under ``torch.no_grad()``,
    so no gradients flow through the normalization.
    """

    def __init__(self, eps=1e-8):
        # Small constant added to the std to keep the division finite.
        self.eps = eps

    def __call__(self, tensor):
        # tensor: (B, C, T) — normalize along the time axis.
        with torch.no_grad():
            mu = tensor.mean(dim=2, keepdim=True)
            sigma = tensor.std(dim=2, keepdim=True)
            return tensor.sub(mu).div(sigma.add(self.eps))
+
class MaskTask(pl.LightningModule):
    """
    PyTorch Lightning module for pretraining a model with masked reconstruction.

    A block-wise rectangular mask is generated per batch, the (optionally
    channel-normalized) input is passed through the model together with the
    mask, and the reconstruction criterion is evaluated on masked regions
    plus a weighted term on unmasked regions. Optionally a query
    specialization loss computed from the model's attention scores is added.

    Args:
        hparams (DictConfig): Parameters and configurations loaded via Hydra.
    """

    # "Plot once per dataset" flags, keyed by the dataset's channel count
    # (stringified). '20' is disabled on purpose.
    _DEFAULT_PLOT_FLAGS = {'12': True, '22': True, '2': True, '20': False, '29': True, '64': True}

    def __init__(self, hparams):
        super().__init__()
        self.save_hyperparameters(hparams)
        self.model = hydra.utils.instantiate(self.hparams.model)
        self.criterion = hydra.utils.instantiate(self.hparams.criterion)
        # NOTE(review): this instantiation is unconditional, while the
        # training/validation steps guard on
        # `self.hparams.query_specialization_criterion is not None` — confirm
        # the config never sets it to None, or guard here as well.
        self.query_specialization_criterion = QuerySpecializationCriterion(**self.hparams.query_specialization_criterion)
        # Masking configuration: (patch_H, patch_W) rectangle size, fraction
        # of rectangles to mask, and the weight of the unmasked-region loss.
        self.patch_size = self.hparams.masking.patch_size
        self.masking_ratio = self.hparams.masking.masking_ratio
        self.unmasked_loss_coeff = self.hparams.masking.unmasked_loss_coeff
        # Enable channel-wise input normalization if specified in parameters.
        if self.hparams.input_normalization is not None and self.hparams.input_normalization.normalize:
            self.normalize = True
            self.normalize_fct = ChannelWiseNormalize()
        else:
            self.normalize = False

        # Copy so the class-level defaults are never mutated.
        self.plot_batches_flags = dict(self._DEFAULT_PLOT_FLAGS)

    def generate_mask(self, batch_size, C, T):
        """
        Generate a boolean mask for block-wise rectangular masking.

        Args:
            batch_size (int): Batch size.
            C (int): Number of channels (height).
            T (int): Temporal length (width).

        Returns:
            torch.BoolTensor: Boolean mask of shape (batch_size, C, T),
            with True in the masked regions. The same mask is shared by
            every sample in the batch.
        """
        patch_H, patch_W = self.patch_size

        # Total number of (patch_H x patch_W) rectangles tiling the grid.
        num_rectangles = (C // patch_H) * (T // patch_W)
        num_to_mask = int(num_rectangles * self.masking_ratio)

        row_indices = torch.arange(0, C, patch_H)
        col_indices = torch.arange(0, T, patch_W)
        rectangles = [(i, j) for i in row_indices for j in col_indices]

        # Randomly select which rectangles to mask.
        selected_indices = torch.randperm(num_rectangles)[:num_to_mask]

        mask = torch.zeros(batch_size, C, T, dtype=torch.bool).to(self.device)

        # Set mask to True in the selected regions.
        for idx in selected_indices:
            r, c = rectangles[idx]
            mask[:, r:r + patch_H, c:c + patch_W] = True

        return mask

    def training_step(self, batch, batch_idx):
        """
        Training step: validate input, apply mask, normalize and compute loss.

        Args:
            batch (dict): Batch with "input", "channel_locations",
                "sensor_type" and optionally "channel_names".
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value.
        """
        X = batch["input"]
        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]
        mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])

        # Validate the input BEFORE the forward pass so a corrupted batch is
        # reported as such instead of surfacing as NaNs in the model output.
        if torch.isnan(X).any() or torch.isinf(X).any():
            print("!!! Input X_ORIGINAL contains NaN or Inf at step 0!!!")
            # Print the data source or indices for debugging
            raise RuntimeError("Input data is corrupted.")

        if self.normalize:
            X = self.normalize_fct(X)

        # Pass masked input through the model to get reconstruction and embeddings.
        x_reconstructed, x_original, attention_scores = self.model(X, mask, channel_locations, sensor_type, channel_names)
        if torch.isnan(x_reconstructed).any():
            print("!!! Model output X_RECONSTRUCTED contains NaN at step", self.global_step, "!!!")
            raise ValueError("NaN detected in model output.")

        # Compute loss on masked parts, plus a down-weighted unmasked term.
        masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask, self.patch_size[1])
        loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss
        if self.hparams.query_specialization_criterion is not None:
            query_specialization_loss = self.query_specialization_criterion(attention_scores)
            loss += query_specialization_loss
            self.log('query_specialization_loss', query_specialization_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log('train_loss', masked_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """
        Validation step: apply mask, normalize, compute loss and log signals.

        Args:
            batch (dict): Batch with "input", "channel_locations",
                "sensor_type" and optionally "channel_names".
            batch_idx (int): Batch index.

        Returns:
            torch.Tensor: Loss value.
        """
        X = batch["input"]
        channel_locations = batch["channel_locations"]
        channel_names = batch.get("channel_names", None)
        sensor_type = batch["sensor_type"]
        mask = self.generate_mask(X.shape[0], X.shape[1], X.shape[2])

        if self.normalize:
            X = self.normalize_fct(X)

        # Forward pass — mirrors training_step (sensor_type was previously
        # omitted here, making the call inconsistent with training).
        x_reconstructed, x_original, attention_scores = self.model(X, mask, channel_locations, sensor_type, channel_names)

        masked_loss, unmasked_loss = self.criterion(x_reconstructed, x_original, mask, self.patch_size[1])
        loss = masked_loss + self.unmasked_loss_coeff * unmasked_loss

        if self.hparams.query_specialization_criterion is not None:
            query_specialization_loss = self.query_specialization_criterion(attention_scores)
            loss += query_specialization_loss
            self.log('query_specialization_loss', query_specialization_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)

        self.log('val_loss', loss, prog_bar=True, on_step=True, on_epoch=True, logger=True, sync_dist=True)

        # Fixed sample indices for logging; drop those beyond the batch size
        # so a small final batch cannot raise an IndexError.
        plot_indices = [i for i in (6, 16, 30) if i < X.shape[0]]
        dataset_id = str(sensor_type.shape[1])
        # Log signals only for the first validation batch of each dataset;
        # .get() keeps unknown channel counts from raising a KeyError.
        if self.plot_batches_flags.get(dataset_id, False) and plot_indices:
            self.log_signals_with_mask(
                x_original.float(),
                x_reconstructed.float(),
                f"Reconstruction {dataset_id}",
                mask,
                batch_indices=plot_indices,
                indice_batch=batch_idx
            )
            self.plot_batches_flags[dataset_id] = False

        return loss

    def on_validation_epoch_end(self):
        # Re-arm the plotting flags for the next epoch — assumes no shuffling.
        self.plot_batches_flags = dict(self._DEFAULT_PLOT_FLAGS)

    def configure_optimizers(self):
        """
        Configure optimizer and scheduler based on parameters.

        Returns:
            dict: Dictionary with optimizer and scheduler for PyTorch Lightning.

        Raises:
            NotImplementedError: If `optimizer.optim` is not one of
                SGD / Adam / AdamW / LAMB.
        """
        if self.hparams.optimizer.optim == "SGD":
            optimizer = torch.optim.SGD(self.model.parameters(), lr=self.hparams.optimizer.lr, momentum=0.9)
        elif self.hparams.optimizer.optim == 'Adam':
            optimizer = torch.optim.Adam(self.model.parameters(), lr=self.hparams.optimizer.lr, weight_decay=0.01)
        elif self.hparams.optimizer.optim == 'AdamW':
            optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.hparams.optimizer.lr)
        elif self.hparams.optimizer.optim == 'LAMB':
            optimizer = torch_optim.Lamb(self.model.parameters(), lr=self.hparams.optimizer.lr)
        else:
            raise NotImplementedError("No valid optim name")

        scheduler = hydra.utils.instantiate(self.hparams.scheduler, optimizer, total_training_opt_steps=self.trainer.estimated_stepping_batches)
        # Step the scheduler every optimizer step (not every epoch).
        lr_scheduler_config = {
            "scheduler": scheduler,
            "interval": "step",
            "frequency": 1,
        }

        return {"optimizer": optimizer, "lr_scheduler": lr_scheduler_config}

    def lr_scheduler_step(self, scheduler, metric):
        # timm-style schedulers expose step_update(num_updates) instead of step().
        scheduler.step_update(num_updates=self.global_step)

    def log_signals_with_mask(self, original, reconstructed, title, mask=None, batch_indices=None, indice_batch=None):
        """
        Log original vs. reconstructed traces of channel 1 to TensorBoard,
        highlighting the masked patches of that channel.

        Args:
            original (torch.Tensor): Original signals of shape (B, C, T).
            reconstructed (torch.Tensor): Reconstructed signals (B, C, T).
            title (str): Tag prefix for the TensorBoard figure.
            mask (torch.BoolTensor, optional): Applied mask of shape (B, C, T).
            batch_indices (list[int], optional): Sample indices to log.
            indice_batch (int, optional): Current batch index (unused).
        """
        patch_H, patch_W = self.patch_size
        batch_size, C, T = original.shape

        for batch_idx in batch_indices:
            fig, ax = plt.subplots(1, 1, figsize=(15, 6))

            # Visualize a single channel (channel index 1).
            original_trace = original[batch_idx, 1].cpu().numpy()
            reconstructed_trace = reconstructed[batch_idx, 1].cpu().numpy()

            ax.plot(original_trace, label='Original Channel 0', color='blue', alpha=0.7)
            ax.plot(reconstructed_trace, label='Reconstructed Channel 0', color='orange', alpha=0.7)

            if mask is not None:
                # Highlight fully masked patches of the plotted channel with a
                # light gray transparent band. (The previous version iterated
                # patch_H rows over a single-row slice, which would raise an
                # IndexError for patch_H > 1.)
                channel_mask = mask[batch_idx, 1]
                for j in range(T // patch_W):
                    if channel_mask[j * patch_W:(j + 1) * patch_W].all():
                        ax.axvspan(j * patch_W, (j + 1) * patch_W, color='lightgray', alpha=0.1)

            ax.set_title(f"Signal Reconstruction - batch_ {batch_idx}")
            ax.legend()

            # Log the figure on TensorBoard with the sample index in the tag.
            self.logger.experiment.add_figure(f'{title}_batch {batch_idx}', fig, self.current_epoch)
            plt.close(fig)
diff --git a/pretrain_task_PanLUNA.yaml b/pretrain_task_PanLUNA.yaml
new file mode 100644
index 0000000..cf36213
--- /dev/null
+++ b/pretrain_task_PanLUNA.yaml
@@ -0,0 +1,22 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+task:
+ _target_: 'tasks.pretrain_task_PanLUNA.MaskTask'
\ No newline at end of file
diff --git a/pretraining_datasets_PanLUNA.py b/pretraining_datasets_PanLUNA.py
new file mode 100644
index 0000000..cf958c4
--- /dev/null
+++ b/pretraining_datasets_PanLUNA.py
@@ -0,0 +1,96 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import torch
+import h5py
+import numpy as np
+from models.modules.lead_positions import (
+ get_channel_indices,
+ get_channel_locations,
+ map_lead_labels_to_angles,
+)
+
class BioSignal_Dataset(torch.utils.data.Dataset):
    """
    Unified map-style dataset for pretraining data stored in HDF5 files.

    Supports EEG (TUEG, Siena) and ECG/ECG&PPG (Code15, MIMIC-IV, PulseDB)
    HDF5 files. Each HDF5 group is expected to contain an "X" dataset of
    per-sample signal arrays; a top-level group literally named 'index_map'
    is skipped. Channel locations, channel indices and sensor-type tensors
    are precomputed once and returned unchanged with every item.

    Args:
        hdf5_file: Path to the .h5 file.
        channels: Full list of channel names; possibly truncated by num_channels.
        location_fn: "eeg" (3-D electrode coordinates) or anything else
            (ECG lead angles via map_lead_labels_to_angles).
        sensor_type: 0 (ECG), 1 (EEG), 2 (PPG) — either a single int applied
            to all channels or a per-channel list.
        num_channels: Number of channels taken from the start of `channels`;
            None keeps them all. NOTE(review): num_channels=0 is falsy and
            also keeps all channels — confirm that is intended.
    """
    def __init__(
        self,
        hdf5_file: str,
        channels: list[str],
        location_fn: str,
        sensor_type: int | list[int],
        num_channels: int | None = None,
    ):
        super().__init__()
        # Optionally keep only the first num_channels channel names.
        channel_names = channels[:num_channels] if num_channels else channels
        self.num_channels = len(channel_names)

        # Spatial metadata: 3-D electrode locations for EEG, lead angles otherwise.
        if location_fn == "eeg":
            locs = np.stack(get_channel_locations(channel_names), axis=0)
            self.channel_locations = torch.from_numpy(locs).float()
        else:
            self.channel_locations = torch.FloatTensor(map_lead_labels_to_angles(channel_names))

        # Integer channel identifiers returned as "channel_names" in __getitem__.
        self.channel_indices = torch.tensor(get_channel_indices(channel_names), dtype=torch.long)

        # Broadcast a scalar sensor type to all channels, or use the per-channel list.
        if isinstance(sensor_type, list):
            self.sensor_type = torch.tensor(sensor_type, dtype=torch.long)
        else:
            self.sensor_type = torch.full((self.num_channels, ), sensor_type, dtype=torch.long)

        # NOTE(review): the h5py handle is opened eagerly and kept for the
        # dataset's lifetime; open h5py.File objects are not picklable, so
        # confirm this works with the DataLoader worker setup in use
        # (e.g. fork start method) — otherwise open lazily per worker.
        self.data = h5py.File(hdf5_file, "r")
        self.keys = list(self.data.keys())
        self.index_map = []

        # Flatten (group, row) pairs so a single integer indexes any sample.
        for key in self.keys:
            if key == 'index_map':
                continue
            group_size = len(self.data[key]["X"])
            self.index_map.extend([(key, i) for i in range(group_size)])

    def __len__(self):
        # Total number of samples across all HDF5 groups.
        return len(self.index_map)

    def __getitem__(self, index):
        # Map the flat index back to its HDF5 group and row.
        group_key, sample_idx = self.index_map[index]
        grp = self.data[group_key]
        X = torch.FloatTensor(grp["X"][sample_idx])

        # Per-channel metadata tensors are shared across all items.
        item = {
            "input": X,
            "channel_names": self.channel_indices,
            "channel_locations": self.channel_locations,
            "sensor_type": self.sensor_type
        }

        return item

    def __del__(self):
        # Close the HDF5 handle; guard in case __init__ failed before opening.
        if hasattr(self, "data"):
            self.data.close()
\ No newline at end of file
diff --git a/process_raw_ecg.py b/process_raw_ecg.py
new file mode 100644
index 0000000..7939181
--- /dev/null
+++ b/process_raw_ecg.py
@@ -0,0 +1,121 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+import wfdb
+import numpy as np
+import pandas as pd
+import pickle
+import argparse
+import tqdm
+import ast
+import os
+import math
+
+from scipy.signal import butter, filtfilt, iirnotch, resample_poly
+from mat73 import loadmat
+
def preprocess_signal(waveform, fs, low, high, downsample_fs=None, upsample_fs=None):
    """
    Preprocess an ECG/PPG waveform: sanitize non-finite samples, optionally
    upsample, bandpass filter, notch out mains interference (50/60 Hz) and
    optionally downsample.

    Args:
        waveform (np.ndarray): Signal array of shape (num_channels, time),
            e.g. (12, 5000) for a 12-lead ECG.
        fs (int): Sampling frequency of the input signal in Hz.
        low (float): Low cut-off frequency of the bandpass filter in Hz.
        high (float): High cut-off frequency of the bandpass filter in Hz.
        downsample_fs (int, optional): Target frequency for downsampling,
            applied after filtering. If None, no downsampling is applied.
        upsample_fs (int, optional): Target frequency for upsampling,
            applied before filtering. If None, no upsampling is applied.

    Returns:
        np.ndarray: Filtered (and possibly resampled) waveform.
    """
    # -------- Missing values -----------------
    # Replace NaN/Inf samples with zeros so the IIR filters stay stable.
    # (For now PTB-XL doesn't contain any; a smarter imputation may follow.)
    if not np.isfinite(waveform).all():
        waveform = np.nan_to_num(waveform, nan=0.0, posinf=0.0, neginf=0.0)

    # ------- Upsampling -----------------
    if upsample_fs is not None:
        waveform = resample_poly(waveform, up=upsample_fs, down=fs, axis=-1)
        fs = upsample_fs

    # -------- Bandpass filtering ------------
    # 4th-order Butterworth, zero-phase via filtfilt.
    b, a = butter(4, [low / (fs / 2), high / (fs / 2)], btype="band")
    filtered = filtfilt(b, a, waveform, axis=-1)

    # -------- Notch filtering -------------
    # Remove 50 Hz and 60 Hz mains interference. Each notch is applied only
    # when its frequency lies strictly below the Nyquist frequency —
    # otherwise iirnotch would receive an invalid normalized frequency and
    # raise (e.g. for fs <= 120 Hz).
    q = 30  # Quality factor: notch bandwidth is notch_freq / q.
    for notch_freq in (50, 60):
        if notch_freq < fs / 2:
            b_notch, a_notch = iirnotch(notch_freq, q, fs)
            filtered = filtfilt(b_notch, a_notch, filtered, axis=-1)

    # -------- Downsampling ----------------
    if downsample_fs is not None:
        # General resampling for arbitrary integer rates: resample by the
        # rational factor downsample_fs / fs, reduced via the GCD so
        # resample_poly gets the smallest integer up/down factors.
        gcd_value = math.gcd(int(fs), int(downsample_fs))
        up_factor = int(downsample_fs) // gcd_value
        down_factor = int(fs) // gcd_value

        resampled = resample_poly(filtered, up=up_factor, down=down_factor, axis=-1)
    else:
        resampled = filtered

    return resampled
+
def time_segmenting(signal, split_signal, sampling_rate, downsample_fs=None, upsample_fs=None):
    """
    Cut a signal into consecutive, equally sized time segments.

    The segment length in samples is ``split_signal`` (seconds) times the
    effective sampling rate: ``downsample_fs`` or ``upsample_fs`` when the
    signal was resampled during preprocessing (at most one should be set),
    and ``sampling_rate`` otherwise. The signal length must be an exact
    multiple of the segment length — no padding is performed.

    Args:
        signal (np.ndarray): Processed signal of shape [num_channels, time_length].
        split_signal (int): Segment length in seconds.
        sampling_rate (int): Original sampling rate of the signal.
        downsample_fs (int, optional): Downsampling frequency, if any.
        upsample_fs (int, optional): Upsampling frequency, if any.

    Returns:
        list[np.ndarray]: Segments of shape [num_channels, segment_samples].
    """
    # Effective sampling rate after any resampling in preprocessing.
    if downsample_fs is not None:
        effective_fs = downsample_fs
    elif upsample_fs is not None:
        effective_fs = upsample_fs
    else:
        effective_fs = sampling_rate

    segment_samples = effective_fs * split_signal

    # No padding: the time axis must divide evenly into segments.
    assert signal.shape[-1] % segment_samples == 0, \
        f"Signal length ({signal.shape[-1]} must be exactly divisible by {segment_samples})."

    return np.split(signal, signal.shape[-1] // segment_samples, axis=-1)
\ No newline at end of file
From b7d68e797872928969844f21030a0c4f3f69e72b Mon Sep 17 00:00:00 2001
From: Marija Zelic <145926170+masazelic@users.noreply.github.com>
Date: Mon, 27 Apr 2026 09:51:06 +0200
Subject: [PATCH 2/3] Add citation for PanLUNA model in README
---
README.md | 9 +++++++++
1 file changed, 9 insertions(+)
diff --git a/README.md b/README.md
index 4d71334..b9157ab 100644
--- a/README.md
+++ b/README.md
@@ -502,6 +502,15 @@ If you find this work useful, please cite the respective papers:
primaryClass={cs.AI},
url={https://arxiv.org/abs/2603.19100},
}
+@misc{zelic2026panluna,
+ title={PanLUNA: An Efficient and Robust Query-Unified Multimodal Model for Edge Biosignal Intelligence},
+ author={Marija Zelic and Anna Tegon and Yawei Li and Thorir Mar Ingolfsson},
+ year={2026},
+ eprint={2604.04297},
+ archivePrefix={arXiv},
+ primaryClass={cs.AI},
+ url={https://arxiv.org/abs/2604.04297},
+}
```
## License
From 2f23a79a2bd99bb73dec9b24201b156d383c5c90 Mon Sep 17 00:00:00 2001
From: Masa Zelic
Date: Tue, 28 Apr 2026 15:02:50 +0200
Subject: [PATCH 3/3] File org
---
.../data_module/dataset_types.yaml | 0
...netune_data_module_multimodal_PanLUNA.yaml | 0
...finetune_data_module_unimodal_PanLUNA.yaml | 0
.../multiloader_data_module_PanLUNA.yaml | 0
config/experiment/PanLUNA_finetune.yaml | 120 ++++++++++++++++++
config/experiment/PanLUNA_pretrain.yaml | 97 ++++++++++++++
.../model/PanLUNA_finetune.yaml | 0
.../model/PanLUNA_pretrain.yaml | 0
.../task/finetune_task_PanLUNA.yaml | 0
.../task/pretrain_task_PanLUNA.yaml | 0
.../finetuning_multimodal_datasets_PanLUNA.py | 0
.../finetuning_unimodal_datasets_PanLUNA.py | 0
.../pretraining_datasets_PanLUNA.py | 0
PanLUNA.md => docs/model/PanLUNA.md | 0
.../make_code15_dataset.py | 0
.../make_cpsc2018_dataset.py | 0
.../make_csn_dataset.py | 0
.../make_hmc_dataset.py | 0
.../make_mimic_iv_dataset.py | 0
.../make_ptbxl_dataset.py | 0
.../make_pulsedb_dataset.py | 0
.../make_seed_vii_dataset.py | 0
.../make_siena_dataset.py | 0
.../make_wesad_dataset.py | 0
.../process_raw_ecg.py | 0
PanLUNA.py => models/PanLUNA.py | 0
.../modules/lead_positions.py | 0
requirements.txt | 1 +
.../finetune_task_PanLUNA.py | 0
.../pretrain_task_PanLUNA.py | 0
30 files changed, 218 insertions(+)
rename dataset_types.yaml => config/data_module/dataset_types.yaml (100%)
rename finetune_data_module_multimodal_PanLUNA.yaml => config/data_module/finetune_data_module_multimodal_PanLUNA.yaml (100%)
rename finetune_data_module_unimodal_PanLUNA.yaml => config/data_module/finetune_data_module_unimodal_PanLUNA.yaml (100%)
rename multiloader_data_module_PanLUNA.yaml => config/data_module/multiloader_data_module_PanLUNA.yaml (100%)
create mode 100644 config/experiment/PanLUNA_finetune.yaml
create mode 100644 config/experiment/PanLUNA_pretrain.yaml
rename PanLUNA_finetune.yaml => config/model/PanLUNA_finetune.yaml (100%)
rename PanLUNA_pretrain.yaml => config/model/PanLUNA_pretrain.yaml (100%)
rename finetune_task_PanLUNA.yaml => config/task/finetune_task_PanLUNA.yaml (100%)
rename pretrain_task_PanLUNA.yaml => config/task/pretrain_task_PanLUNA.yaml (100%)
rename finetuning_multimodal_datasets_PanLUNA.py => datasets/finetuning_multimodal_datasets_PanLUNA.py (100%)
rename finetuning_unimodal_datasets_PanLUNA.py => datasets/finetuning_unimodal_datasets_PanLUNA.py (100%)
rename pretraining_datasets_PanLUNA.py => datasets/pretraining_datasets_PanLUNA.py (100%)
rename PanLUNA.md => docs/model/PanLUNA.md (100%)
rename make_code15_dataset.py => make_datasets/make_code15_dataset.py (100%)
rename make_cpsc2018_dataset.py => make_datasets/make_cpsc2018_dataset.py (100%)
rename make_csn_dataset.py => make_datasets/make_csn_dataset.py (100%)
rename make_hmc_dataset.py => make_datasets/make_hmc_dataset.py (100%)
rename make_mimic_iv_dataset.py => make_datasets/make_mimic_iv_dataset.py (100%)
rename make_ptbxl_dataset.py => make_datasets/make_ptbxl_dataset.py (100%)
rename make_pulsedb_dataset.py => make_datasets/make_pulsedb_dataset.py (100%)
rename make_seed_vii_dataset.py => make_datasets/make_seed_vii_dataset.py (100%)
rename make_siena_dataset.py => make_datasets/make_siena_dataset.py (100%)
rename make_wesad_dataset.py => make_datasets/make_wesad_dataset.py (100%)
rename process_raw_ecg.py => make_datasets/process_raw_ecg.py (100%)
rename PanLUNA.py => models/PanLUNA.py (100%)
rename lead_positions.py => models/modules/lead_positions.py (100%)
rename finetune_task_PanLUNA.py => tasks/finetune_task_PanLUNA.py (100%)
rename pretrain_task_PanLUNA.py => tasks/pretrain_task_PanLUNA.py (100%)
diff --git a/dataset_types.yaml b/config/data_module/dataset_types.yaml
similarity index 100%
rename from dataset_types.yaml
rename to config/data_module/dataset_types.yaml
diff --git a/finetune_data_module_multimodal_PanLUNA.yaml b/config/data_module/finetune_data_module_multimodal_PanLUNA.yaml
similarity index 100%
rename from finetune_data_module_multimodal_PanLUNA.yaml
rename to config/data_module/finetune_data_module_multimodal_PanLUNA.yaml
diff --git a/finetune_data_module_unimodal_PanLUNA.yaml b/config/data_module/finetune_data_module_unimodal_PanLUNA.yaml
similarity index 100%
rename from finetune_data_module_unimodal_PanLUNA.yaml
rename to config/data_module/finetune_data_module_unimodal_PanLUNA.yaml
diff --git a/multiloader_data_module_PanLUNA.yaml b/config/data_module/multiloader_data_module_PanLUNA.yaml
similarity index 100%
rename from multiloader_data_module_PanLUNA.yaml
rename to config/data_module/multiloader_data_module_PanLUNA.yaml
diff --git a/config/experiment/PanLUNA_finetune.yaml b/config/experiment/PanLUNA_finetune.yaml
new file mode 100644
index 0000000..99e1347
--- /dev/null
+++ b/config/experiment/PanLUNA_finetune.yaml
@@ -0,0 +1,120 @@
+
+# @package _global_
+tag: PanLUNA_finetune
+model_size: tiny
+gpus: 4
+num_nodes: 1
+num_workers: 8
+batch_size: 256
+
+training: True
+final_validate: True
+final_test: True
+finetune_pretrained: True
+resume: False
+find_unused_parameters: True
+
+label_smoothing: 0
+layerwise_lr_decay: 0.75
+scheduler_type: cosine
+# Change according to the type of task:
+# bc - binary classification
+# mlp - multilabel classification
+# mcc - multiclass classification
+classification_type: "#CHANGEME"
+
+pretrained_checkpoint_path: "#CHANGEME"
+
+callbacks:
+ progress_bar:
+ _target_: 'pytorch_lightning.callbacks.TQDMProgressBar'
+ refresh_rate: 50
+ early_stopping:
+ _target_: 'pytorch_lightning.callbacks.EarlyStopping'
+ monitor: 'val_loss'
+ patience: 12
+ mode: 'min'
+ verbose: True
+
+input_normalization:
+ normalize: True
+
+finetuning:
+ mode: "lora" # options: "full", "freeze_encoder", "lora"
+ lora:
+ r: 16
+ alpha: 32
+ dropout: 0.10
+ target_modules:
+ - 'blocks.0.attn.qkv_proj'
+ - 'blocks.1.attn.qkv_proj'
+ - 'blocks.2.attn.qkv_proj'
+ - 'blocks.3.attn.qkv_proj'
+ - 'blocks.4.attn.qkv_proj'
+ - 'blocks.5.attn.qkv_proj'
+ - 'blocks.0.attn.proj'
+ - 'blocks.1.attn.proj'
+ - 'blocks.2.attn.proj'
+ - 'blocks.3.attn.proj'
+ - 'blocks.4.attn.proj'
+ - 'blocks.5.attn.proj'
+ - 'cross_attn.cross_attention'
+ - 'cross_attn.cross_attention.out_proj'
+ - 'cross_attn.query_self_attn.layers.0.self_attn'
+ - 'cross_attn.query_self_attn.layers.1.self_attn'
+ - 'cross_attn.query_self_attn.layers.2.self_attn'
+ - 'cross_attn.query_self_attn.layers.0.self_attn.out_proj'
+ - 'cross_attn.query_self_attn.layers.1.self_attn.out_proj'
+ - 'cross_attn.query_self_attn.layers.2.self_attn.out_proj'
+
+io:
+ checkpoint_dirpath: ${env:CHECKPOINT_DIR}/checkpoints
+ version: ${tag}_${model_size}_finetune
+ base_output_path: "#CHANGEME"
+
+defaults:
+ - override /data_module: finetune_data_module_unimodal_PanLUNA # or finetune_data_module_multimodal_PanLUNA
+ - override /model: PanLUNA_finetune
+ - override /scheduler: cosine
+ - override /task: finetune_task_PanLUNA
+ - override /criterion: finetune_criterion
+
+model:
+ patch_size: 32
+ embed_dim: 64
+ num_heads: 2
+ depth: 6
+ num_queries: 4
+ mlp_ratio: 4
+ drop_path: 0.1
+ num_classes: "#CHANGEME"
+
+trainer:
+ accelerator: gpu
+ num_nodes: ${num_nodes}
+ devices: ${gpus}
+ strategy: ddp
+ max_epochs: 100
+ precision: bf16-mixed
+
+model_checkpoint:
+ save_last: True
+ monitor: "val_loss"
+ mode: "min"
+ save_top_k: 1
+ every_n_epochs: 1
+
+optimizer:
+ optim: 'AdamW'
+ lr: 5.0e-4
+ betas: [0.9, 0.999]
+ weight_decay: 0.05
+
+scheduler:
+ trainer: ${trainer}
+ min_lr: 5.0e-6 # minimum LR for the cosine scheduler
+ warmup_lr_init: 2.5e-7 # initial LR for the warmup phase
+ warmup_epochs: 0
+
+
+
diff --git a/config/experiment/PanLUNA_pretrain.yaml b/config/experiment/PanLUNA_pretrain.yaml
new file mode 100644
index 0000000..2042d20
--- /dev/null
+++ b/config/experiment/PanLUNA_pretrain.yaml
@@ -0,0 +1,97 @@
+#*----------------------------------------------------------------------------*
+#* Copyright (C) 2026 ETH Zurich, Switzerland *
+#* SPDX-License-Identifier: Apache-2.0 *
+#* *
+#* Licensed under the Apache License, Version 2.0 (the "License"); *
+#* you may not use this file except in compliance with the License. *
+#* You may obtain a copy of the License at *
+#* *
+#* http://www.apache.org/licenses/LICENSE-2.0 *
+#* *
+#* Unless required by applicable law or agreed to in writing, software *
+#* distributed under the License is distributed on an "AS IS" BASIS, *
+#* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
+#* See the License for the specific language governing permissions and *
+#* limitations under the License. *
+#* *
+#* Author: Marija Zelic *
+#* Author: Thorir Mar Ingolfsson *
+#*----------------------------------------------------------------------------*
+# @package _global_
+tag: PanLUNA_pretrain
+model_size: tiny
+gpus: 4
+num_nodes: 1
+num_workers: 8
+batch_size: 1024
+
+final_validate: True
+final_test: False
+resume: False
+pretrained_checkpoint_path: null
+freeze_encoder: False
+io:
+ checkpoint_dirpath: "#CHANGEME"
+ checkpoint_filepath: null
+
+defaults:
+ - override /data_module: multiloader_data_module_PanLUNA
+ - override /model: PanLUNA_pretrain
+ - override /scheduler: cosine
+ - override /task: pretrain_task_PanLUNA
+ - override /criterion: pretrain_criterion
+
+query_specialization_criterion:
+ loss_type: 'l2'
+ loss_coeff: 0.8
+
+finetuning:
+ mode: null
+
+model:
+ patch_size: 32
+ embed_dim: 64
+ num_heads: 2
+ depth: 6
+ num_queries: 4
+ mlp_ratio: 4
+ drop_path: 0.0
+ num_classes: 0
+
+masking:
+ patch_size: [1, 32]
+ masking_ratio: 0.5
+ unmasked_loss_coeff: 0.05
+
+input_normalization:
+ normalize: True
+
+scheduler:
+ trainer: ${trainer}
+ min_lr: 2.5e-7 # minimum LR for the cosine scheduler
+ warmup_lr_init: 2.5e-7 # initial LR for the warmup phase
+ warmup_epochs: 10
+
+trainer:
+ accelerator: gpu
+ num_nodes: ${num_nodes}
+ devices: ${gpus}
+ max_epochs: 150
+ strategy: ddp
+ accumulate_grad_batches: 1
+ check_val_every_n_epoch: 1
+ use_distributed_sampler: False
+ precision: bf16-mixed
+ gradient_clip_val: 1.0
+
+model_checkpoint:
+ save_last: True
+ monitor: "val_loss"
+ mode: "min"
+ save_top_k: 1
+
+optimizer:
+ optim: 'AdamW'
+ lr: 1.25e-4
+ betas: [0.9, 0.98]
+ weight_decay: 0.05
\ No newline at end of file
diff --git a/PanLUNA_finetune.yaml b/config/model/PanLUNA_finetune.yaml
similarity index 100%
rename from PanLUNA_finetune.yaml
rename to config/model/PanLUNA_finetune.yaml
diff --git a/PanLUNA_pretrain.yaml b/config/model/PanLUNA_pretrain.yaml
similarity index 100%
rename from PanLUNA_pretrain.yaml
rename to config/model/PanLUNA_pretrain.yaml
diff --git a/finetune_task_PanLUNA.yaml b/config/task/finetune_task_PanLUNA.yaml
similarity index 100%
rename from finetune_task_PanLUNA.yaml
rename to config/task/finetune_task_PanLUNA.yaml
diff --git a/pretrain_task_PanLUNA.yaml b/config/task/pretrain_task_PanLUNA.yaml
similarity index 100%
rename from pretrain_task_PanLUNA.yaml
rename to config/task/pretrain_task_PanLUNA.yaml
diff --git a/finetuning_multimodal_datasets_PanLUNA.py b/datasets/finetuning_multimodal_datasets_PanLUNA.py
similarity index 100%
rename from finetuning_multimodal_datasets_PanLUNA.py
rename to datasets/finetuning_multimodal_datasets_PanLUNA.py
diff --git a/finetuning_unimodal_datasets_PanLUNA.py b/datasets/finetuning_unimodal_datasets_PanLUNA.py
similarity index 100%
rename from finetuning_unimodal_datasets_PanLUNA.py
rename to datasets/finetuning_unimodal_datasets_PanLUNA.py
diff --git a/pretraining_datasets_PanLUNA.py b/datasets/pretraining_datasets_PanLUNA.py
similarity index 100%
rename from pretraining_datasets_PanLUNA.py
rename to datasets/pretraining_datasets_PanLUNA.py
diff --git a/PanLUNA.md b/docs/model/PanLUNA.md
similarity index 100%
rename from PanLUNA.md
rename to docs/model/PanLUNA.md
diff --git a/make_code15_dataset.py b/make_datasets/make_code15_dataset.py
similarity index 100%
rename from make_code15_dataset.py
rename to make_datasets/make_code15_dataset.py
diff --git a/make_cpsc2018_dataset.py b/make_datasets/make_cpsc2018_dataset.py
similarity index 100%
rename from make_cpsc2018_dataset.py
rename to make_datasets/make_cpsc2018_dataset.py
diff --git a/make_csn_dataset.py b/make_datasets/make_csn_dataset.py
similarity index 100%
rename from make_csn_dataset.py
rename to make_datasets/make_csn_dataset.py
diff --git a/make_hmc_dataset.py b/make_datasets/make_hmc_dataset.py
similarity index 100%
rename from make_hmc_dataset.py
rename to make_datasets/make_hmc_dataset.py
diff --git a/make_mimic_iv_dataset.py b/make_datasets/make_mimic_iv_dataset.py
similarity index 100%
rename from make_mimic_iv_dataset.py
rename to make_datasets/make_mimic_iv_dataset.py
diff --git a/make_ptbxl_dataset.py b/make_datasets/make_ptbxl_dataset.py
similarity index 100%
rename from make_ptbxl_dataset.py
rename to make_datasets/make_ptbxl_dataset.py
diff --git a/make_pulsedb_dataset.py b/make_datasets/make_pulsedb_dataset.py
similarity index 100%
rename from make_pulsedb_dataset.py
rename to make_datasets/make_pulsedb_dataset.py
diff --git a/make_seed_vii_dataset.py b/make_datasets/make_seed_vii_dataset.py
similarity index 100%
rename from make_seed_vii_dataset.py
rename to make_datasets/make_seed_vii_dataset.py
diff --git a/make_siena_dataset.py b/make_datasets/make_siena_dataset.py
similarity index 100%
rename from make_siena_dataset.py
rename to make_datasets/make_siena_dataset.py
diff --git a/make_wesad_dataset.py b/make_datasets/make_wesad_dataset.py
similarity index 100%
rename from make_wesad_dataset.py
rename to make_datasets/make_wesad_dataset.py
diff --git a/process_raw_ecg.py b/make_datasets/process_raw_ecg.py
similarity index 100%
rename from process_raw_ecg.py
rename to make_datasets/process_raw_ecg.py
diff --git a/PanLUNA.py b/models/PanLUNA.py
similarity index 100%
rename from PanLUNA.py
rename to models/PanLUNA.py
diff --git a/lead_positions.py b/models/modules/lead_positions.py
similarity index 100%
rename from lead_positions.py
rename to models/modules/lead_positions.py
diff --git a/requirements.txt b/requirements.txt
index 8804a2b..ab79d32 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -10,6 +10,7 @@ tqdm
pandas
mne
nvidia-nccl-cu11
+peft
psutil
pyyaml
tensorboardX
diff --git a/finetune_task_PanLUNA.py b/tasks/finetune_task_PanLUNA.py
similarity index 100%
rename from finetune_task_PanLUNA.py
rename to tasks/finetune_task_PanLUNA.py
diff --git a/pretrain_task_PanLUNA.py b/tasks/pretrain_task_PanLUNA.py
similarity index 100%
rename from pretrain_task_PanLUNA.py
rename to tasks/pretrain_task_PanLUNA.py