@@ -34,10 +34,7 @@ def __init__(self, **kwargs):
3434 """
3535 super ().__init__ ()
3636 self .data = self .store_data (** kwargs )
37- < << << << HEAD
3837 self .has_custom_dataloader_fn = False
39- == == == =
40- >> >> >> > 4 ca4993a (Conditions refactoring (#758))
4138
4239 @property
4340 def problem (self ):
@@ -90,12 +87,8 @@ def automatic_batching_collate_fn(cls, batch):
9087 if not batch :
9188 return {}
9289 instance_class = batch [0 ].__class__
93- < << << << HEAD
9490 batch = instance_class .create_batch (batch )
9591 return batch
96- == == == =
97- return instance_class .create_batch (batch )
98- >> >> >> > 4 ca4993a (Conditions refactoring (#758))
9992
10093 @staticmethod
10194 def collate_fn (batch , condition ):
@@ -113,15 +106,11 @@ def collate_fn(batch, condition):
113106 return data
114107
115108 def create_dataloader (
116- << << << < HEAD
117109 self ,
118110 dataset ,
119111 batch_size ,
120112 automatic_batching ,
121113 ** kwargs ,
122- == == == =
123- self , dataset , batch_size , shuffle , automatic_batching
124- > >> >> >> 4 ca4993a (Conditions refactoring (#758))
125114 ):
126115 """
127116 Create a DataLoader for the condition.
@@ -132,23 +121,14 @@ def create_dataloader(
132121 :rtype: torch.utils.data.DataLoader
133122 """
134123 if batch_size == len (dataset ):
135- << < << < < HEAD
136124 return DummyDataloader (dataset )
137125 return DataLoader (
138126 dataset = dataset ,
139- == == == =
140- pass # will be updated in the near future
141- return DataLoader (
142- dataset = dataset ,
143- batch_size = batch_size ,
144- shuffle = shuffle ,
145- >> >> >> > 4 ca4993a (Conditions refactoring (#758))
146127 collate_fn = (
147128 partial (self .collate_fn , condition = self )
148129 if not automatic_batching
149130 else self .automatic_batching_collate_fn
150131 ),
151- << << << < HEAD
152132 batch_size = batch_size ,
153133 ** kwargs ,
154134 )
@@ -166,6 +146,3 @@ def switch_dataloader_fn(self, create_dataloader_fn):
166146 # the new function
167147 self .has_custom_dataloader_fn = True
168148 self .create_dataloader = create_dataloader_fn
169- == == == =
170- )
171- >> > >> >> 4 ca4993a (Conditions refactoring (#758))
0 commit comments