@@ -34,7 +34,10 @@ def __init__(self, **kwargs):
3434 """
3535 super ().__init__ ()
3636 self .data = self .store_data (** kwargs )
37+ < << << << HEAD
3738 self .has_custom_dataloader_fn = False
39+ == == == =
40+ >> >> >> > 4 ca4993a (Conditions refactoring (#758))
3841
3942 @property
4043 def problem (self ):
@@ -87,8 +90,12 @@ def automatic_batching_collate_fn(cls, batch):
8790 if not batch :
8891 return {}
8992 instance_class = batch [0 ].__class__
93+ < << << << HEAD
9094 batch = instance_class .create_batch (batch )
9195 return batch
96+ == == == =
97+ return instance_class .create_batch (batch )
98+ >> >> >> > 4 ca4993a (Conditions refactoring (#758))
9299
93100 @staticmethod
94101 def collate_fn (batch , condition ):
@@ -106,11 +113,15 @@ def collate_fn(batch, condition):
106113 return data
107114
108115 def create_dataloader (
116+ << << << < HEAD
109117 self ,
110118 dataset ,
111119 batch_size ,
112120 automatic_batching ,
113121 ** kwargs ,
122+ == == == =
123+ self , dataset , batch_size , shuffle , automatic_batching
124+ > >> >> >> 4 ca4993a (Conditions refactoring (#758))
114125 ):
115126 """
116127 Create a DataLoader for the condition.
@@ -121,14 +132,23 @@ def create_dataloader(
121132 :rtype: torch.utils.data.DataLoader
122133 """
123134 if batch_size == len (dataset ):
135+ << < << < < HEAD
124136 return DummyDataloader (dataset )
125137 return DataLoader (
126138 dataset = dataset ,
139+ == == == =
140+ pass # will be updated in the near future
141+ return DataLoader (
142+ dataset = dataset ,
143+ batch_size = batch_size ,
144+ shuffle = shuffle ,
145+ >> >> >> > 4 ca4993a (Conditions refactoring (#758))
127146 collate_fn = (
128147 partial (self .collate_fn , condition = self )
129148 if not automatic_batching
130149 else self .automatic_batching_collate_fn
131150 ),
151+ << << << < HEAD
132152 batch_size = batch_size ,
133153 ** kwargs ,
134154 )
@@ -146,3 +166,6 @@ def switch_dataloader_fn(self, create_dataloader_fn):
146166 # the new function
147167 self .has_custom_dataloader_fn = True
148168 self .create_dataloader = create_dataloader_fn
169+ == == == =
170+ )
171+ >> > >> >> 4 ca4993a (Conditions refactoring (#758))
0 commit comments