@@ -36,7 +36,7 @@ def __init__(
3636 test_size = 0.0 ,
3737 val_size = 0.0 ,
3838 compile = None ,
39- repeat = None ,
39+ batching_mode = "common_batch_size" ,
4040 automatic_batching = None ,
4141 num_workers = None ,
4242 pin_memory = None ,
@@ -61,9 +61,9 @@ def __init__(
6161 :param bool compile: If ``True``, the model is compiled before training.
6262 Default is ``False``. For Windows users, it is always disabled. Not
6363 supported for python version greater or equal than 3.14.
64- :param bool repeat: Whether to repeat the dataset data in each
65- condition during training. For further details, see the
66- :class:`~pina.data.data_module.PinaDataModule` class . Default is
64+ :param str batching_mode: The batching mode to use. Options are
65+ ``"common_batch_size"``, ``"proportional"``, and
66+ ``"separate_conditions"`` . Default is ``"common_batch_size"``.
6767 ``False``.
6868 :param bool automatic_batching: If ``True``, automatic PyTorch batching
6969 is performed, otherwise the items are retrieved from the dataset
@@ -87,7 +87,7 @@ def __init__(
8787 train_size = train_size ,
8888 test_size = test_size ,
8989 val_size = val_size ,
90- repeat = repeat ,
90+ batching_mode = batching_mode ,
9191 automatic_batching = automatic_batching ,
9292 compile = compile ,
9393 )
@@ -127,24 +127,44 @@ def __init__(
127127 UserWarning ,
128128 )
129129
130- repeat = repeat if repeat is not None else False
131-
132130 automatic_batching = (
133131 automatic_batching if automatic_batching is not None else False
134132 )
135133
134+ if batch_size is None and batching_mode != "common_batch_size" :
135+ warnings .warn (
136+ "Batching mode is set to "
137+ f"{ batching_mode } but batch_size is None. "
138+ "Batching mode will be set to common_batch_size." ,
139+ UserWarning ,
140+ )
141+ batching_mode = "common_batch_size"
142+
143+ if (
144+ batch_size is not None
145+ and batch_size <= len (solver .problem .conditions )
146+ and batching_mode == "proportional"
147+ ):
148+ warnings .warn (
149+ "Batching mode is set to proportional but batch_size is 1. "
150+ "Batching mode will be set to common_batch_size." ,
151+ UserWarning ,
152+ )
153+ batching_mode = "common_batch_size"
154+
136155 # set attributes
137156 self .compile = compile
138157 self .solver = solver
139158 self .batch_size = batch_size
140159 self ._move_to_device ()
141160 self .data_module = None
161+
142162 self ._create_datamodule (
143163 train_size = train_size ,
144164 test_size = test_size ,
145165 val_size = val_size ,
146166 batch_size = batch_size ,
147- repeat = repeat ,
167+ batching_mode = batching_mode ,
148168 automatic_batching = automatic_batching ,
149169 pin_memory = pin_memory ,
150170 num_workers = num_workers ,
@@ -182,7 +202,7 @@ def _create_datamodule(
182202 test_size ,
183203 val_size ,
184204 batch_size ,
185- repeat ,
205+ batching_mode ,
186206 automatic_batching ,
187207 pin_memory ,
188208 num_workers ,
@@ -201,8 +221,9 @@ def _create_datamodule(
201221 :param float val_size: The percentage of elements to include in the
202222 validation dataset.
203223 :param int batch_size: The number of samples per batch to load.
204- :param bool repeat: Whether to repeat the dataset data in each
205- condition during training.
224+ :param str batching_mode: The batching mode to use. Options are
225+ ``"common_batch_size"``, ``"proportional"``, and
226+ ``"separate_conditions"``.
206227 :param bool automatic_batching: Whether to perform automatic batching
207228 with PyTorch.
208229 :param bool pin_memory: Whether to use pinned memory for faster data
@@ -232,7 +253,7 @@ def _create_datamodule(
232253 test_size = test_size ,
233254 val_size = val_size ,
234255 batch_size = batch_size ,
235- repeat = repeat ,
256+ batching_mode = batching_mode ,
236257 automatic_batching = automatic_batching ,
237258 num_workers = num_workers ,
238259 pin_memory = pin_memory ,
@@ -284,7 +305,7 @@ def _check_input_consistency(
284305 train_size ,
285306 test_size ,
286307 val_size ,
287- repeat ,
308+ batching_mode ,
288309 automatic_batching ,
289310 compile ,
290311 ):
@@ -298,8 +319,9 @@ def _check_input_consistency(
298319 test dataset.
299320 :param float val_size: The percentage of elements to include in the
300321 validation dataset.
301- :param bool repeat: Whether to repeat the dataset data in each
302- condition during training.
322+ :param str batching_mode: The batching mode to use. Options are
323+ ``"common_batch_size"``, ``"proportional"``, and
324+ ``"separate_conditions"``.
303325 :param bool automatic_batching: Whether to perform automatic batching
304326 with PyTorch.
305327 :param bool compile: If ``True``, the model is compiled before training.
@@ -309,8 +331,7 @@ def _check_input_consistency(
309331 check_consistency (train_size , float )
310332 check_consistency (test_size , float )
311333 check_consistency (val_size , float )
312- if repeat is not None :
313- check_consistency (repeat , bool )
334+ check_consistency (batching_mode , str )
314335 if automatic_batching is not None :
315336 check_consistency (automatic_batching , bool )
316337 if compile is not None :
0 commit comments