Skip to content

Commit 5dab5fb

Browse files
committed
DataModule refactoring (#766)
1 parent cd0bf5b commit 5dab5fb

File tree

13 files changed

+870
-1138
lines changed

13 files changed

+870
-1138
lines changed

pina/_src/condition/condition_base.py

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pina._src.condition.condition_interface import ConditionInterface
1010
from pina._src.core.graph import LabelBatch
1111
from pina._src.core.label_tensor import LabelTensor
12+
from pina._src.data.dummy_dataloader import DummyDataloader
1213

1314

1415
class ConditionBase(ConditionInterface):
@@ -33,6 +34,7 @@ def __init__(self, **kwargs):
3334
"""
3435
super().__init__()
3536
self.data = self.store_data(**kwargs)
37+
self.has_custom_dataloader_fn = False
3638

3739
@property
3840
def problem(self):
@@ -85,7 +87,8 @@ def automatic_batching_collate_fn(cls, batch):
8587
if not batch:
8688
return {}
8789
instance_class = batch[0].__class__
88-
return instance_class.create_batch(batch)
90+
batch = instance_class.create_batch(batch)
91+
return batch
8992

9093
@staticmethod
9194
def collate_fn(batch, condition):
@@ -103,7 +106,11 @@ def collate_fn(batch, condition):
103106
return data
104107

105108
def create_dataloader(
106-
self, dataset, batch_size, shuffle, automatic_batching
109+
self,
110+
dataset,
111+
batch_size,
112+
automatic_batching,
113+
**kwargs,
107114
):
108115
"""
109116
Create a DataLoader for the condition.
@@ -114,14 +121,28 @@ def create_dataloader(
114121
:rtype: torch.utils.data.DataLoader
115122
"""
116123
if batch_size == len(dataset):
117-
pass # will be updated in the near future
124+
return DummyDataloader(dataset)
118125
return DataLoader(
119126
dataset=dataset,
120-
batch_size=batch_size,
121-
shuffle=shuffle,
122127
collate_fn=(
123128
partial(self.collate_fn, condition=self)
124129
if not automatic_batching
125130
else self.automatic_batching_collate_fn
126131
),
132+
batch_size=batch_size,
133+
**kwargs,
127134
)
135+
136+
def switch_dataloader_fn(self, create_dataloader_fn):
137+
"""
138+
Decorator to switch the dataloader function for a condition.
139+
140+
:param create_dataloader_fn: The new dataloader function to use.
141+
:type create_dataloader_fn: function
142+
:return: The decorated function with the new dataloader function.
143+
:rtype: function
144+
"""
145+
# Replace the create_dataloader method of the ConditionBase class with
146+
# the new function
147+
self.has_custom_dataloader_fn = True
148+
self.create_dataloader = create_dataloader_fn

pina/_src/condition/data_manager.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,6 +119,7 @@ def create_batch(items):
119119
if isinstance(sample, LabelTensor)
120120
else torch.stack
121121
)
122+
batch_data[k] = batch_fn(vals)
122123
batch_data[k] = batch_fn(vals, dim=0)
123124
else:
124125
batch_data[k] = sample

pina/_src/core/trainer.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def __init__(
3636
test_size=0.0,
3737
val_size=0.0,
3838
compile=None,
39-
repeat=None,
39+
batching_mode="common_batch_size",
4040
automatic_batching=None,
4141
num_workers=None,
4242
pin_memory=None,
@@ -61,9 +61,9 @@ def __init__(
6161
:param bool compile: If ``True``, the model is compiled before training.
6262
Default is ``False``. For Windows users, it is always disabled. Not
6363
supported for python version greater or equal than 3.14.
64-
:param bool repeat: Whether to repeat the dataset data in each
65-
condition during training. For further details, see the
66-
:class:`~pina.data.data_module.PinaDataModule` class. Default is
64+
:param str batching_mode: The batching mode to use. Options are
65+
``"common_batch_size"``, ``"proportional"``, and
66+
``"separate_conditions"``. Default is ``"common_batch_size"``.
6868
:param bool automatic_batching: If ``True``, automatic PyTorch batching
6969
is performed, otherwise the items are retrieved from the dataset
@@ -87,7 +87,7 @@ def __init__(
8787
train_size=train_size,
8888
test_size=test_size,
8989
val_size=val_size,
90-
repeat=repeat,
90+
batching_mode=batching_mode,
9191
automatic_batching=automatic_batching,
9292
compile=compile,
9393
)
@@ -127,24 +127,44 @@ def __init__(
127127
UserWarning,
128128
)
129129

130-
repeat = repeat if repeat is not None else False
131-
132130
automatic_batching = (
133131
automatic_batching if automatic_batching is not None else False
134132
)
135133

134+
if batch_size is None and batching_mode != "common_batch_size":
135+
warnings.warn(
136+
"Batching mode is set to "
137+
f"{batching_mode} but batch_size is None. "
138+
"Batching mode will be set to common_batch_size.",
139+
UserWarning,
140+
)
141+
batching_mode = "common_batch_size"
142+
143+
if (
144+
batch_size is not None
145+
and batch_size <= len(solver.problem.conditions)
146+
and batching_mode == "proportional"
147+
):
148+
warnings.warn(
149+
"Batching mode is set to proportional but batch_size is 1. "
150+
"Batching mode will be set to common_batch_size.",
151+
UserWarning,
152+
)
153+
batching_mode = "common_batch_size"
154+
136155
# set attributes
137156
self.compile = compile
138157
self.solver = solver
139158
self.batch_size = batch_size
140159
self._move_to_device()
141160
self.data_module = None
161+
142162
self._create_datamodule(
143163
train_size=train_size,
144164
test_size=test_size,
145165
val_size=val_size,
146166
batch_size=batch_size,
147-
repeat=repeat,
167+
batching_mode=batching_mode,
148168
automatic_batching=automatic_batching,
149169
pin_memory=pin_memory,
150170
num_workers=num_workers,
@@ -182,7 +202,7 @@ def _create_datamodule(
182202
test_size,
183203
val_size,
184204
batch_size,
185-
repeat,
205+
batching_mode,
186206
automatic_batching,
187207
pin_memory,
188208
num_workers,
@@ -201,8 +221,9 @@ def _create_datamodule(
201221
:param float val_size: The percentage of elements to include in the
202222
validation dataset.
203223
:param int batch_size: The number of samples per batch to load.
204-
:param bool repeat: Whether to repeat the dataset data in each
205-
condition during training.
224+
:param str batching_mode: The batching mode to use. Options are
225+
``"common_batch_size"``, ``"proportional"``, and
226+
``"separate_conditions"``.
206227
:param bool automatic_batching: Whether to perform automatic batching
207228
with PyTorch.
208229
:param bool pin_memory: Whether to use pinned memory for faster data
@@ -232,7 +253,7 @@ def _create_datamodule(
232253
test_size=test_size,
233254
val_size=val_size,
234255
batch_size=batch_size,
235-
repeat=repeat,
256+
batching_mode=batching_mode,
236257
automatic_batching=automatic_batching,
237258
num_workers=num_workers,
238259
pin_memory=pin_memory,
@@ -284,7 +305,7 @@ def _check_input_consistency(
284305
train_size,
285306
test_size,
286307
val_size,
287-
repeat,
308+
batching_mode,
288309
automatic_batching,
289310
compile,
290311
):
@@ -298,8 +319,9 @@ def _check_input_consistency(
298319
test dataset.
299320
:param float val_size: The percentage of elements to include in the
300321
validation dataset.
301-
:param bool repeat: Whether to repeat the dataset data in each
302-
condition during training.
322+
:param str batching_mode: The batching mode to use. Options are
323+
``"common_batch_size"``, ``"proportional"``, and
324+
``"separate_conditions"``.
303325
:param bool automatic_batching: Whether to perform automatic batching
304326
with PyTorch.
305327
:param bool compile: If ``True``, the model is compiled before training.
@@ -309,8 +331,7 @@ def _check_input_consistency(
309331
check_consistency(train_size, float)
310332
check_consistency(test_size, float)
311333
check_consistency(val_size, float)
312-
if repeat is not None:
313-
check_consistency(repeat, bool)
334+
check_consistency(batching_mode, str)
314335
if automatic_batching is not None:
315336
check_consistency(automatic_batching, bool)
316337
if compile is not None:

pina/_src/data/aggregator.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
"""
2+
Aggregator for multiple dataloaders.
3+
"""
4+
5+
6+
class _Aggregator:
7+
"""
8+
The class :class:`_Aggregator` is responsible for aggregating multiple
9+
dataloaders into a single iterable object. It supports different batching
10+
modes to accommodate various training requirements.
11+
"""
12+
13+
def __init__(self, dataloaders, batching_mode):
14+
"""
15+
Initialization of the :class:`_Aggregator` class.
16+
17+
:param dataloaders: A dictionary mapping condition names to their
18+
respective dataloaders.
19+
:type dataloaders: dict[str, DataLoader]
20+
:param batching_mode: The batching mode to use. Options are
21+
``"common_batch_size"``, ``"proportional"``, and
22+
``"separate_conditions"``.
23+
:type batching_mode: str
24+
"""
25+
self.dataloaders = dataloaders
26+
self.batching_mode = batching_mode
27+
28+
def __len__(self):
29+
"""
30+
Return the length of the aggregated dataloader.
31+
32+
:return: The length of the aggregated dataloader.
33+
:rtype: int
34+
"""
35+
if self.batching_mode == "separate_conditions":
36+
return sum(len(dl) for dl in self.dataloaders.values())
37+
return max(len(dl) for dl in self.dataloaders.values())
38+
39+
def __iter__(self):
40+
"""
41+
Return an iterator over the aggregated dataloader.
42+
43+
:return: An iterator over the aggregated dataloader.
44+
:rtype: iterator
45+
"""
46+
if self.batching_mode == "separate_conditions":
47+
# TODO: implement separate_conditions batching mode
48+
raise NotImplementedError(
49+
"Batching mode 'separate_conditions' is not implemented yet."
50+
)
51+
52+
iterators = {name: iter(dl) for name, dl in self.dataloaders.items()}
53+
for _ in range(len(self)):
54+
batch = {}
55+
for name, it in iterators.items():
56+
try:
57+
batch[name] = next(it)
58+
except StopIteration:
59+
iterators[name] = iter(self.dataloaders[name])
60+
batch[name] = next(iterators[name])
61+
yield batch

0 commit comments

Comments
 (0)