-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathpretrain_data_module.py
More file actions
151 lines (129 loc) · 5.9 KB
/
pretrain_data_module.py
File metadata and controls
151 lines (129 loc) · 5.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# *----------------------------------------------------------------------------*
# * Copyright (C) 2025 ETH Zurich, Switzerland *
# * SPDX-License-Identifier: Apache-2.0 *
# * *
# * Licensed under the Apache License, Version 2.0 (the "License"); *
# * you may not use this file except in compliance with the License. *
# * You may obtain a copy of the License at *
# * *
# * http://www.apache.org/licenses/LICENSE-2.0 *
# * *
# * Unless required by applicable law or agreed to in writing, software *
# * distributed under the License is distributed on an "AS IS" BASIS, *
# * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. *
# * See the License for the specific language governing permissions and *
# * limitations under the License. *
# * *
# * Author: Anna Tegon *
# * Author: Thorir Mar Ingolfsson *
# *----------------------------------------------------------------------------*
from typing import Optional
import pytorch_lightning as pl
import torch
from torch.utils.data import ConcatDataset, DataLoader
class PretrainDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for pretraining that manages multiple datasets,
    splitting them into training and validation subsets, and provides train/val
    dataloaders with configurable batch size and worker count.

    Args:
        datasets (dict of Dataset): Dictionary of datasets to be concatenated and
            split. Entries whose value is None are skipped.
        test (Dataset, optional): Optional test dataset, used by setup("test").
        cfg (Config): Configuration object containing batch_size and num_workers.
            Required; a clear ValueError is raised if omitted.
        name (str, optional): Name identifier for this data module.
        train_val_split_ratio (float, optional): Ratio for train/validation split
            (default=0.8).
        **kwargs: Additional arguments (accepted for interface compatibility,
            currently unused).

    Raises:
        ValueError: If cfg is None.
    """

    def __init__(
        self,
        datasets: dict,
        test=None,
        cfg=None,
        name="",
        train_val_split_ratio=0.8,
        **kwargs,
    ):
        super().__init__()
        if cfg is None:
            # Fail early with an actionable message instead of the opaque
            # AttributeError that reading cfg.batch_size below would raise.
            raise ValueError(
                "cfg must be provided and define batch_size and num_workers."
            )
        # Filter out None entries; only real datasets are split and concatenated.
        datasets_list = [ds for ds in datasets.values() if ds is not None]
        print("datasets list:", datasets_list)  # debug output
        # Per-dataset train/val splits, concatenated afterwards so every source
        # dataset contributes the same proportion to each split.
        self.train, self.val = [], []
        for dataset in datasets_list:
            train_size = int(train_val_split_ratio * len(dataset))
            val_size = len(dataset) - train_size
            train, val = torch.utils.data.random_split(dataset, [train_size, val_size])
            self.train.append(train)
            self.val.append(val)
        # Concatenate all training splits into a single training dataset.
        self.train = ConcatDataset(self.train)
        # Concatenate all validation splits into a single validation dataset.
        self.val = ConcatDataset(self.val)
        self.test = test
        self.name = name
        self.cfg = cfg
        self.batch_size = self.cfg.batch_size
        print(len(self.train), len(self.val))  # debug output

    def setup(self, stage: Optional[str] = None):
        """
        Prepare datasets for different stages: 'fit', 'validate', or 'test'.

        Args:
            stage (str, optional): Stage name. Options: 'fit', 'validate',
                'test', or None (None behaves like 'fit').
        """
        if stage == "fit" or stage is None:
            # Assign train and validation datasets for the training phase.
            self.train_dataset = self.train
            self.val_dataset = self.val
        elif stage == "validate":
            # Assign validation dataset for the validation phase.
            self.val_dataset = self.val
        elif stage == "test":
            # Prefer the dedicated test set when one was supplied at
            # construction; fall back to the validation split otherwise.
            self.test_dataset = self.test if self.test is not None else self.val

    def _make_dataloader(self, dataset, shuffle):
        """Build a DataLoader with the module-wide loader settings.

        persistent_workers is only enabled when workers exist: PyTorch raises
        a ValueError for persistent_workers=True with num_workers=0.
        """
        workers = self.cfg.num_workers
        return DataLoader(
            dataset,
            num_workers=workers,
            pin_memory=True,
            persistent_workers=workers > 0,
            shuffle=shuffle,
            batch_size=self.batch_size,
            drop_last=True,
        )

    def train_dataloader(self):
        """
        Returns the DataLoader for the training dataset.

        Raises:
            ValueError: If setup() hasn't been called before this.

        Returns:
            DataLoader: DataLoader with shuffling enabled for training.
        """
        if not hasattr(self, "train_dataset"):
            raise ValueError(
                "Setup method must be called before accessing train_dataloader."
            )
        return self._make_dataloader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """
        Returns the DataLoader for the validation dataset.

        Raises:
            ValueError: If setup() hasn't been called before this.

        Returns:
            DataLoader: DataLoader without shuffling for validation.
        """
        if not hasattr(self, "val_dataset"):
            raise ValueError(
                "Setup method must be called before accessing val_dataloader."
            )
        return self._make_dataloader(self.val_dataset, shuffle=False)